-
-
Notifications
You must be signed in to change notification settings - Fork 8.4k
[Kernel][FP8] Initial support with dynamic per-tensor scaling #4118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
bf19419
add fp8
comaniac 65a9999
work
comaniac ad07afe
lint
comaniac f2ff3e5
comments
comaniac dbe46a6
done
comaniac 4a8d923
revert
comaniac 4bd55a5
comment
comaniac 2dee644
wip
comaniac 6a0a8ba
work
comaniac 7974ccf
lint
comaniac ca416f3
done
comaniac ee69c1b
fix ci
comaniac 2088110
fix
comaniac e7b2fc8
rename
comaniac 1f95739
comment
comaniac 5b70cd1
fix test
pcmoritz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
"""Tests whether FP8 computation is enabled correctly. | ||
|
||
Run `pytest tests/quantization/test_fp8.py --forked`. | ||
""" | ||
import pytest | ||
import torch | ||
|
||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS | ||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod | ||
|
||
capability = torch.cuda.get_device_capability() | ||
capability = capability[0] * 10 + capability[1] | ||
|
||
|
||
@pytest.mark.skipif( | ||
capability < QUANTIZATION_METHODS["fp8"].get_min_capability(), | ||
reason="FP8 is not supported on this GPU type.") | ||
def test_load_fp16_model(vllm_runner) -> None: | ||
llm = vllm_runner("facebook/opt-125m", quantization="fp8") | ||
|
||
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model | ||
fc1 = model.model.decoder.layers[0].fc1 | ||
assert isinstance(fc1.linear_method, Fp8LinearMethod) | ||
assert fc1.weight.dtype == torch.float8_e4m3fn |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
from typing import Any, Dict, List, Optional | ||
|
||
import torch | ||
from torch.nn import Module | ||
from torch.nn.parameter import Parameter | ||
|
||
from vllm.model_executor.layers.linear import (LinearMethodBase, | ||
set_weight_attrs) | ||
from vllm.model_executor.layers.quantization.base_config import ( | ||
QuantizationConfig) | ||
|
||
|
||
class FP8Config(QuantizationConfig): | ||
"""Config class for FP8.""" | ||
|
||
@classmethod | ||
def get_name(cls) -> str: | ||
return "fp8" | ||
|
||
@classmethod | ||
def get_supported_act_dtypes(cls) -> List[torch.dtype]: | ||
return [torch.bfloat16, torch.half] | ||
|
||
@classmethod | ||
def get_min_capability(cls) -> int: | ||
# TODO: PyTorch 2.3.0+ is required to run FP8 on | ||
# SM 89 (e.g. Ada) GPUs. Specifically, this PR has to | ||
# be included: https://github.com/pytorch/pytorch/pull/118881 | ||
return 90 | ||
|
||
@classmethod | ||
def get_config_filenames(cls) -> List[str]: | ||
return [] | ||
|
||
@classmethod | ||
def from_config(cls, config: Dict[str, Any]) -> "FP8Config": | ||
return cls() | ||
|
||
def get_linear_method(self) -> "Fp8LinearMethod": | ||
return Fp8LinearMethod(self) | ||
|
||
def get_scaled_act_names(self) -> List[str]: | ||
return [] | ||
|
||
|
||
class Fp8LinearMethod(LinearMethodBase): | ||
"""Linear method for FP8. | ||
We now support common FP16/BF16 model checkpoints ONLY. The weight | ||
scaling factor will be initialized after the model weights are loaded. | ||
|
||
Limitations: | ||
1. Only support per-tensor quantization due to torch._scaled_mm support. | ||
2. Only support float8_e4m3fn data type due to the limitation of | ||
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) | ||
|
||
Args: | ||
quant_config: The quantization config. | ||
""" | ||
|
||
def __init__(self, quant_config: FP8Config): | ||
self.quant_config = quant_config | ||
|
||
def create_weights( | ||
self, | ||
layer: torch.nn.Module, | ||
input_size_per_partition: int, | ||
output_size_per_partition: int, | ||
input_size: int, | ||
output_size: int, | ||
params_dtype: torch.dtype, | ||
**extra_weight_attrs, | ||
): | ||
weight = Parameter(torch.empty(output_size_per_partition, | ||
input_size_per_partition, | ||
dtype=params_dtype), | ||
requires_grad=False) | ||
layer.register_parameter("weight", weight) | ||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) | ||
set_weight_attrs(weight, extra_weight_attrs) | ||
|
||
w_scale = Parameter( | ||
comaniac marked this conversation as resolved.
Show resolved
Hide resolved
|
||
torch.empty(1, dtype=torch.float32), | ||
requires_grad=False, | ||
) | ||
layer.register_parameter("weight_scaling_factor", w_scale) | ||
|
||
def process_weights_after_loading(self, layer: Module) -> None: | ||
# Although the linear_method is propagated to all layers, | ||
# only linear layers invoke "create_weights". So we check | ||
# whether "weight_scaling_facor" is registered to determine | ||
# whether the layer is a linear layer that requires quantization. | ||
if not hasattr(layer, "weight_scaling_factor"): | ||
return | ||
comaniac marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
qweight, weight_scale = per_tensor_quantize(layer.weight) | ||
# torch._scaled_mm requires column-major in the second | ||
# input (weight), so we transpose the quantized weight. | ||
layer.weight = Parameter(qweight.t(), requires_grad=False) | ||
layer.weight_scaling_factor.data.copy_(weight_scale) | ||
|
||
def apply_weights(self, | ||
layer: torch.nn.Module, | ||
x: torch.Tensor, | ||
bias: Optional[torch.Tensor] = None) -> torch.Tensor: | ||
qinput, x_scale = per_tensor_quantize(x) | ||
output, _ = torch._scaled_mm( | ||
qinput, | ||
layer.weight, | ||
out_dtype=x.dtype, | ||
scale_a=x_scale, | ||
scale_b=layer.weight_scaling_factor, | ||
bias=bias, | ||
) | ||
return output | ||
|
||
|
||
def per_tensor_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, float]: | ||
"""Quantize a tensor using per-tensor static scaling factor. | ||
|
||
Args: | ||
tensor: The input tensor. | ||
""" | ||
finfo = torch.finfo(torch.float8_e4m3fn) | ||
# Calculate the scale as dtype max divided by absmax. | ||
# Since .abs() creates a new tensor, we use aminmax to get | ||
# the min and max first and then calculate the absmax. | ||
min_val, max_val = tensor.aminmax() | ||
amax = min_val.abs().max(max_val.abs()) | ||
scale = finfo.max / amax.clamp(min=1e-12) | ||
# scale and clamp the tensor to bring it to | ||
# the representative range of float8 data type | ||
# (as default cast is unsaturated) | ||
qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max) | ||
# Return both float8 data and the inverse scale (as float), | ||
# as both required as inputs to torch._scaled_mm | ||
qweight = qweight.to(torch.float8_e4m3fn) | ||
scale = scale.float().reciprocal() | ||
return qweight, scale |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.