
[FP8] Refactor apply_fp8_linear and apply_fp8_linear_generic into an object #14390

Merged · 3 commits · Mar 7, 2025
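In short, this PR turns the free functions `apply_fp8_linear` / `apply_fp8_linear_generic` into stateful objects: call sites construct `Fp8LinearOp` (or `Fp8LinearGenericOp`) once, binding `cutlass_fp8_supported` and `use_per_token_if_dynamic` at construction time, and then call `.apply(...)` with only the tensors. A minimal sketch of the new pattern, mirroring the updated call sites in the diff below; the shapes, scales, and the FP8-capable CUDA device are illustrative assumptions, not taken from the PR:

```python
import torch

from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp, maybe_create_device_identity)

maybe_create_device_identity()  # imported alongside Fp8LinearOp in test_fusion.py

hidden_size = 128
# Illustrative FP8 weight and per-tensor scale; real layers load these from
# checkpoints (see the quantization method hunks below).
weight = torch.rand(hidden_size, hidden_size,
                    device="cuda").to(torch.float8_e4m3fn).t()
weight_scale = torch.ones(1, dtype=torch.float32, device="cuda")

# Flags that used to be threaded through every apply_fp8_linear(...) call are
# now bound once when the op is constructed.
fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

x = torch.rand(4, hidden_size, device="cuda", dtype=torch.float16)
out = fp8_linear.apply(input=x,
                       weight=weight,
                       weight_scale=weight_scale,
                       input_scale=None,  # dynamic activation quantization
                       bias=None)
```

Binding the flags in the constructor also lets `Fp8LinearMethod` decide the per-token default from `cutlass_fp8_supported()` once at init time (see the `fp8.py` hunk below) instead of passing it on every forward call.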
20 changes: 7 additions & 13 deletions tests/compile/test_fusion.py
@@ -13,7 +13,7 @@
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity)
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)

from .backend import TestBackend

@@ -34,26 +34,20 @@ def __init__(self, hidden_size: int, eps: float, static: bool,
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(2)
]
self.fp8_linear = Fp8LinearOp(
cutlass_fp8_supported=cutlass_fp8_enabled,
use_per_token_if_dynamic=True)

def forward(self, x):
resid = torch.sqrt(x)
y = self.norm[0](x)

x2 = apply_fp8_linear(y,
self.w[0],
self.wscale[0],
self.scale[0],
use_per_token_if_dynamic=True,
cutlass_fp8_supported=self.cutlass_fp8_enabled)
x2 = self.fp8_linear.apply(y, self.w[0], self.wscale[0], self.scale[0])
# make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid)

x3 = apply_fp8_linear(y2,
self.w[1],
self.wscale[1],
self.scale[1],
use_per_token_if_dynamic=True,
cutlass_fp8_supported=self.cutlass_fp8_enabled)
x3 = self.fp8_linear.apply(y2, self.w[1], self.wscale[1],
self.scale[1])
y3, resid = self.norm[2](x3, resid) # use resid here
return y3

7 changes: 4 additions & 3 deletions vllm/attention/backends/mla/common.py
@@ -226,7 +226,7 @@
CompressedTensorsW8A8Fp8)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import (
@@ -1057,6 +1057,7 @@ def __init__(
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
self.triton_fa_func = triton_attention
self.fp8_linear_generic = Fp8LinearGenericOp()

# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the
@@ -1071,7 +1072,7 @@
def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_UV_O):
output_parallel = apply_fp8_linear_generic(
output_parallel = self.fp8_linear_generic.apply(
x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
self.reqaunt_input_group_shape,
self.reqaunt_weight_group_shape)
@@ -1091,7 +1092,7 @@ def _v_up_proj_and_o_proj(self, x):
def _q_proj_and_k_up_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_Q_UK):
return apply_fp8_linear_generic(
return self.fp8_linear_generic.apply(
x, self.W_Q_UK, self.W_Q_UK_scales,
self.reqaunt_input_group_shape,
self.reqaunt_weight_group_shape).view(
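The block-quantized helper used by the MLA matrix-absorption path gets the same treatment: `apply_fp8_linear_generic` becomes a `Fp8LinearGenericOp` constructed in `__init__` and reused per call. A rough sketch of that calling pattern, assuming the argument order shown in the hunk above (activations, weight, weight scales, then the requantization group shapes); the wrapper class and its attribute values are hypothetical, only the op construction and `.apply` signature come from the diff:

```python
import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    Fp8LinearGenericOp)


class AbsorbedProjection:
    """Illustrative wrapper, not part of vLLM; attribute names follow the
    diff above (including the reqaunt_* spelling)."""

    def __init__(self, weight: torch.Tensor, weight_scales: torch.Tensor,
                 input_group_shape, weight_group_shape):
        self.W_UV_O = weight
        self.W_UV_O_scales = weight_scales
        self.reqaunt_input_group_shape = input_group_shape
        self.reqaunt_weight_group_shape = weight_group_shape
        # Construct once, as the PR does in the MLA __init__ above.
        self.fp8_linear_generic = Fp8LinearGenericOp()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Per-call signature matches the diff: activations, weight,
        # weight scales, then the requantization group shapes.
        return self.fp8_linear_generic.apply(
            x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
            self.reqaunt_input_group_shape, self.reqaunt_weight_group_shape)
```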
@@ -9,8 +9,8 @@
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, cutlass_fp8_supported, maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
@@ -24,7 +24,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

@classmethod
def get_min_capability(cls) -> int:
@@ -140,11 +140,8 @@ def apply_weights(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

return apply_fp8_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported,
use_per_token_if_dynamic=True)
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias)
22 changes: 9 additions & 13 deletions vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -11,14 +11,12 @@
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.fp8 import cutlass_fp8_supported
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz)
Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter)
from vllm.platforms import current_platform
@@ -37,6 +35,7 @@ def __init__(self, ignore_list: List[str], input_scale_ub: float):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = not current_platform.has_device_capability(89)
self.fp8_linear = Fp8LinearOp()

@classmethod
def get_name(cls) -> str:
@@ -73,7 +72,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):

def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

def create_weights(
self,
@@ -159,12 +158,9 @@ def apply(self,
size_k=layer.input_size_per_partition,
bias=bias)

return apply_fp8_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
input_scale_ub=layer.input_scale_ub,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported,
use_per_token_if_dynamic=True)
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
input_scale_ub=layer.input_scale_ub,
bias=bias)
21 changes: 10 additions & 11 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -23,7 +23,7 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, apply_fp8_linear, convert_to_channelwise,
Fp8LinearOp, all_close_1d, convert_to_channelwise,
cutlass_block_fp8_supported, cutlass_fp8_supported,
maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
per_tensor_dequantize, requantize_with_max_scale)
@@ -137,7 +137,6 @@ class Fp8LinearMethod(LinearMethodBase):

def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()

# For GPUs that lack FP8 hardware support, we can leverage the Marlin
@@ -153,6 +152,10 @@ def __init__(self, quant_config: Fp8Config):
# Marlin doesn't support block-wise fp8
self.use_marlin = False

self.fp8_linear = Fp8LinearOp(
# Default to using per_token quantization if cutlass is supported
use_per_token_if_dynamic=cutlass_fp8_supported())

def create_weights(
self,
layer: torch.nn.Module,
@@ -381,15 +384,11 @@ def apply(self,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
)

return apply_fp8_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported,
# Default to using per_token quantization if cutlass is supported
use_per_token_if_dynamic=self.cutlass_fp8_supported)
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias)


class Fp8MoEMethod(FusedMoEMethodBase):
16 changes: 7 additions & 9 deletions vllm/model_executor/layers/quantization/modelopt.py
@@ -12,7 +12,7 @@
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
Fp8LinearOp, requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)

@@ -95,7 +95,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):

def __init__(self, quant_config: ModelOptFp8Config):
self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.fp8_linear = Fp8LinearOp()

def create_weights(
self,
@@ -157,10 +157,8 @@ def apply(
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_fp8_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported)
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias)
18 changes: 9 additions & 9 deletions vllm/model_executor/layers/quantization/ptpc_fp8.py
@@ -17,7 +17,7 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear)
Fp8LinearOp)
from vllm.platforms import current_platform

ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -93,6 +93,8 @@ def __init__(self, quant_config: PTPCFp8Config):
super().__init__(quant_config=quant_config)
# Force weight quantization
self.quant_config.is_checkpoint_fp8_serialized = False
self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=False,
use_per_token_if_dynamic=True)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data,
@@ -115,11 +117,9 @@ def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

return apply_fp8_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
input_scale_ub=None,
bias=bias,
cutlass_fp8_supported=False,
use_per_token_if_dynamic=True)
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
input_scale_ub=None,
bias=bias)
@@ -7,8 +7,7 @@

from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale)
Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
@@ -22,7 +21,7 @@ class QuarkW8A8Fp8(QuarkScheme):
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
self.qscheme = qscheme
self.is_static_input_scheme = is_static_input_scheme
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

@classmethod
def get_min_capability(cls) -> int:
@@ -132,11 +131,8 @@ def apply_weights(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

return apply_fp8_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported,
use_per_token_if_dynamic=True)
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias)