Fp8 e4m3_fnuz support for rocm #2588
Changes from 2 commits
Commits in this pull request: f772856, 7a7cd5f, b2b5024, 1de9627, 689aa26
@@ -1,7 +1,7 @@
 import torch

 from dataclasses import dataclass
-from typing import Optional, Union, List
+from typing import Optional, Tuple, Union, List
 from loguru import logger

 from text_generation_server.utils.import_utils import SYSTEM
@@ -51,8 +51,32 @@ def get_fp8_linear() -> torch.nn.Module:
     return Fp8Linear


+def normalize_e4m3fn_to_e4m3fnuz(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    assert weight.dtype == torch.float8_e4m3fn
+    # The bits pattern 10000000(-128) represents zero in e4m3fn
+    # but NaN in e4m3fnuz. So here we set it to 0.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_as_int8 = weight.view(torch.int8)
+    ROCM_FP8_NAN_AS_INT = -128
+    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+
+    # For the same bits representation, e4m3fnuz value is half of
+    # the e4m3fn value, so we should double the scaling factor to
+    # get the same dequantized value.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_scale = weight_scale * 2.0
+    if input_scale is not None:
+        input_scale = input_scale * 2.0
+    return weight, weight_scale, input_scale
+
+
 def fp8_quantize(
-    weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
+    weight, scale=None, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
 ):
     if FBGEMM_DYN_AVAILABLE and not scalar:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
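As a quick sanity check of the two comments in normalize_e4m3fn_to_e4m3fnuz (an illustrative snippet, not part of the diff; it assumes a PyTorch build that exposes both FP8 dtypes), the same byte can be decoded under both formats:

import torch

byte = torch.tensor([0x38], dtype=torch.uint8)        # 0x38 encodes 1.0 in e4m3fn
as_fn = byte.view(torch.float8_e4m3fn).float()        # -> 1.0
as_fnuz = byte.view(torch.float8_e4m3fnuz).float()    # -> 0.5: same bits, half the value
assert torch.equal(as_fn, 2.0 * as_fnuz)              # hence the doubled scaling factor

# 0x80 (int8 -128) is -0.0 in e4m3fn but NaN in e4m3fnuz, hence the remap to 0.
nan_byte = torch.tensor([0x80], dtype=torch.uint8)
assert torch.isnan(nan_byte.view(torch.float8_e4m3fnuz).float()).item()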
@@ -62,8 +86,11 @@ def fp8_quantize(

     # weight, scale = quant_weights(weight, torch.int8, False)
     finfo = torch.finfo(qdtype)
-    # Calculate the scale as dtype max divided by absmax
-    scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
+
+    if scale is None:
+        # Calculate the scale as dtype max divided by absmax
+        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)

     # scale and clamp the tensor to bring it to
     # the representative range of float8 data type
     # (as default cast is unsaturated)
@@ -72,6 +99,10 @@ def fp8_quantize(
     # as both required as inputs to torch._scaled_mm
     qweight = qweight.to(qdtype)
     scale = scale.float().reciprocal()

+    if SYSTEM == "rocm":
+        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
+
     return qweight, scale
@@ -92,9 +123,17 @@ def get_weights(self, weights: "Weights", prefix: str):
                 .reshape(-1)
                 .expand(w.shape[0])
             )

+            input_scale = None
+            if weights.has_tensor(f"{prefix}.input_scale"):
+                input_scale = weights.get_tensor(
+                    f"{prefix}.input_scale", to_dtype=False
+                ).reshape(-1)
+
[Review comment] Same for …
[Reply] Updated to use the …

             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
@@ -125,9 +164,24 @@ def get_weights_col_packed(
             )
             scale = scale.reshape(-1).expand(w.shape[0])

+            input_scale = None
+            if weights.get_tensor(f"{prefix}.input_scale"):

[Review comment] Suggested change (suggestion body not captured): ?

+                input_scale = weights.get_tensor(
+                    f"{prefix}.input_scale", to_dtype=False
+                )
+                if input_scale.numel() > 1:
+                    input_scale = weights.get_packed_sharded(
+                        f"{prefix}.input_scale",
+                        dim=0,
+                        block_sizes=block_sizes,
+                        to_dtype=False,
+                    )
+                input_scale = input_scale.reshape(-1).max()
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
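The body of that suggested change was not captured by the export. Judging from the has_tensor guards used by the other loaders in this diff, it presumably swaps the truthiness check on get_tensor for an existence check. A sketch of that assumed fix (my reconstruction, not the reviewer's exact text):

# Assumed reconstruction of the suggestion: test for the tensor's presence
# instead of calling get_tensor inside the condition.
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
    input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)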
@@ -154,9 +208,21 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: in
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)

+            input_scale = [
+                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
+                for p, shape in zip(prefixes, shapes)
+                if weights.has_tensor(f"{p}.input_scale")
+            ]

[Review comment] Given this conditional, we probably need an assertion like …

+            input_scale = (
+                torch.cat(input_scale, dim=0).reshape(-1).max()
+                if len(input_scale) != 0
+                else None
+            )
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
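The assertion the reviewer had in mind is cut off above. A plausible sketch of the intent (my wording, not the reviewer's) checks that input scales are present for either none or all of the merged prefixes:

# Hypothetical assertion: a partial set of input scales across merged prefixes
# is almost certainly a broken checkpoint, so fail loudly before taking the max.
assert len(input_scale) == 0 or len(input_scale) == len(prefixes), (
    f"Expected input_scale for none or all of {prefixes}, found {len(input_scale)}"
)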
@@ -174,9 +240,16 @@ def get_weights_row(self, weights: "Weights", prefix: str):
                 .reshape(-1)
                 .expand(w.shape[0])
             )
+            input_scale = None
+            if weights.has_tensor(f"{prefix}.input_scale"):
+                input_scale = weights.get_tensor(
+                    f"{prefix}.input_scale", to_dtype=False
+                ).reshape(-1)

             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
@@ -191,6 +264,7 @@ class Fp8Weight(Weight):
     weight: torch.Tensor
     dtype: torch.dtype
     weight_scale: Optional[torch.Tensor] = None
+    input_scale: Optional[torch.Tensor] = None
    activation_scale_ub: Optional[float] = None

     def get_linear(self, bias: torch.Tensor):
@@ -200,56 +274,99 @@ def get_linear(self, bias: torch.Tensor):
             # memory. Can be non-contiguous when we e.g. expand from scalars.
             self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
-            self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
+            weight=self.weight,
+            scale=self.weight_scale,
+            dtype=self.dtype,
+            bias=bias,
+            input_scale=self.input_scale,
+            scale_upper_bound=self.activation_scale_ub,
         )


 class Fp8Linear(torch.nn.Module):
[Review comment] Would it be cleaner to have a separate Fp8LinearRocm?
[Reply] Maybe, it depends a bit on how much conditional code we end up with. We did separate FP8 Marlin for this reason.
+    _device_identity_cache = {}
+
     def __init__(
         self,
-        qweight,
-        scale,
-        scale_upper_bound,
-        bias,
-        dtype,
+        qweight: torch.Tensor,
+        scale: torch.Tensor,
+        dtype: torch.dtype,
+        bias: Optional[torch.Tensor] = None,
+        input_scale: Optional[torch.Tensor] = None,
+        scale_upper_bound: Optional[float] = None,
     ) -> None:
         super().__init__()
         if FBGEMM_MM_AVAILABLE:
             log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
+        if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
+            qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                weight=qweight, weight_scale=scale
+            )

         self.dtype = dtype
         self.qweight = qweight
-        self.scale = scale
-        self.scale_upper_bound = (
-            torch.tensor(
-                [scale_upper_bound], dtype=torch.float32, device=qweight.device
-            )
-            if scale_upper_bound is not None
-            else None
+        self.scale = scale.float()
+        self.input_scale = (
+            input_scale.float().reciprocal() if input_scale is not None else None
         )

+        if FBGEMM_MM_AVAILABLE:
+            self.scale_upper_bound = (
+                torch.tensor(
+                    [scale_upper_bound], dtype=torch.float32, device=qweight.device
+                )
+                if scale_upper_bound is not None
+                else None
+            )
+        else:
+            self.scale_upper_bound = scale_upper_bound
+
         self.bias = bias if bias is not None else None

     @classmethod
     def from_unquant(cls, weight, bias, dtype):
         qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
         return cls(
-            qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
+            qweight=qweight,
+            scale=scale,
+            dtype=dtype,
+            bias=bias,
+            input_scale=None,
+            scale_upper_bound=None,
         )

     @classmethod
-    def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+    def from_fp8(
+        cls,
+        weight: torch.Tensor,
+        scale: torch.Tensor,
+        dtype: torch.dtype,
+        bias: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> "Fp8Linear":
+        input_scale = kwargs.get("input_scale", None)
+        scale_upper_bound = kwargs.get("scale_upper_bound", None)
+
         if FBGEMM_DYN_AVAILABLE:
             # fbgemm needs float32 scales.
             scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
-            scale_upper_bound=input_scale,
+            input_scale=input_scale,
+            scale_upper_bound=scale_upper_bound,
             bias=bias,
             dtype=dtype,
         )

+    @classmethod
+    def get_shared_device_identity(cls, device):
+        # Input scaling factors are no longer optional in _scaled_mm starting
+        # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
+        if device not in cls._device_identity_cache:
+            cls._device_identity_cache[device] = torch.ones(1, device=device)
+        return cls._device_identity_cache[device]
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if FBGEMM_MM_AVAILABLE:
             qinput, scale = fp8_quantize(
@@ -266,15 +383,49 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             )
             return y.to(self.dtype)

-        qinput, scale = fp8_quantize(input, scalar=True)
-        output, _ = torch._scaled_mm(
-            qinput,
-            self.qweight.t(),
-            out_dtype=self.dtype,
-            scale_a=scale,
-            scale_b=self.scale,
-            bias=self.bias,
+        qinput, scale = fp8_quantize(
+            input,
+            self.input_scale,
+            scale_upper_bound=self.scale_upper_bound,
+            scalar=True,
         )
+
+        per_tensor_weights = self.scale.numel() == 1
+        per_tensor_activations = scale.numel() == 1
+
+        if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
+            output = torch._scaled_mm(
+                qinput,
+                self.qweight.t(),
+                out_dtype=self.dtype,
+                scale_a=scale,
+                scale_b=self.scale,
+                bias=self.bias,
+            )
+
+            if isinstance(output, tuple) and len(output) == 2:
+                output = output[0]
+        else:
+            device_identity = None
+            if SYSTEM == "rocm":
+                device_identity = self.get_shared_device_identity(self.qweight.device)
+
+            output = torch._scaled_mm(
+                qinput,
+                self.qweight.t(),
+                scale_a=device_identity,
+                scale_b=device_identity,
+                out_dtype=torch.float32,
+            )
+            if isinstance(output, tuple) and len(output) == 2:
+                output = output[0]
+
+            output = output * scale * self.scale.t()
+            if self.bias is not None:
+                output = output + self.bias
+
+            output = output.to(dtype=self.dtype)
+
         return output
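The else branch above runs torch._scaled_mm with dummy identity scales and applies the real scales afterwards. The identity it relies on, sketched with plain float tensors (a toy illustration, not the PR's kernels): if A ≈ A_q * s_a and B ≈ B_q * s_b, then A @ B.T ≈ (A_q @ B_q.T) * s_a * s_b.T.

import torch

A = torch.randn(4, 8)
B = torch.randn(16, 8)
s_a = A.abs().max()                          # per-tensor activation scale (toy)
s_b = B.abs().amax(dim=1, keepdim=True)      # per-row weight scale (toy)
A_q, B_q = A / s_a, B / s_b                  # stand-ins for the FP8 tensors

ref = A @ B.T
reconstructed = (A_q @ B_q.T) * s_a * s_b.T  # same post-scaling as the ROCm branch
assert torch.allclose(ref, reconstructed, atol=1e-4)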
@@ -342,22 +342,19 @@ def get_model(
     model_type = config_dict.get("model_type", None)

     quantization_config = config_dict.get("quantization_config", None)
-    compression_config = config_dict.get("compression_config", None)

[Author comment] @danieldk config renamed to quantisation config.
[Review comment] I think we should check for both keys, at least for the time being. Some customers/users may have checkpoints that still have compression_config.

     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
+        config_groups = quantization_config.get("config_groups", None)
         if method in {"gptq", "awq", "exl2"}:
             log_master(logger.info, f"Auto selecting quantization method {method}")
             quantize = method
-        elif method == "fbgemm_fp8":
+        elif method == "fbgemm_fp8" or method == "fp8":
             log_master(logger.info, "Auto selecting quantization method fp8")
             quantize = "fp8"
-        else:
-            log_master(logger.warning, f"Unknown quantization method {method}")
-    elif compression_config is not None:
-        # TODO: at some point we should probably fully parse the compression
-        # configuration to know which parameters are compressed.
-        config_groups = compression_config.get("config_groups")
-        if config_groups is not None:
+        elif config_groups is not None:
+            # Compression config renamed to quantization_config
+            # TODO: at some point we should probably fully parse the compression
+            # configuration to know which parameters are compressed.
             for _, group in config_groups.items():
                 weights_config = group.get("weights")
                 if weights_config is not None:
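One minimal way to address that review comment, keeping older checkpoints that still ship a compression_config key working, might be the fallback below; this is a sketch, and the handling that was actually merged may differ:

# Sketch: prefer the new key, fall back to the legacy one for older checkpoints.
quantization_config = config_dict.get("quantization_config", None)
if quantization_config is None:
    quantization_config = config_dict.get("compression_config", None)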
@@ -370,6 +367,8 @@ def get_model(
                         )
                         quantize = "fp8"
                         break
+        else:
+            log_master(logger.warning, f"Unknown quantization method {method}")

     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
[Review comment] We should wire up scale at some point for CUDA as well.