Fp8 e4m3_fnuz support for rocm #2588
Changes from all commits: f772856, 7a7cd5f, b2b5024, 1de9627, 689aa26
@@ -1,7 +1,7 @@
 import torch

 from dataclasses import dataclass
-from typing import Optional, Union, List
+from typing import Optional, Tuple, Union, List
 from loguru import logger

 from text_generation_server.utils.import_utils import SYSTEM
@@ -51,8 +51,32 @@ def get_fp8_linear() -> torch.nn.Module:
     return Fp8Linear


+def normalize_e4m3fn_to_e4m3fnuz(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    assert weight.dtype == torch.float8_e4m3fn
+    # The bits pattern 10000000(-128) represents zero in e4m3fn
+    # but NaN in e4m3fnuz. So here we set it to 0.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_as_int8 = weight.view(torch.int8)
+    ROCM_FP8_NAN_AS_INT = -128
+    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+
+    # For the same bits representation, e4m3fnuz value is half of
+    # the e4m3fn value, so we should double the scaling factor to
+    # get the same dequantized value.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_scale = weight_scale * 2.0
+    if input_scale is not None:
+        input_scale = input_scale * 2.0
+    return weight, weight_scale, input_scale
+
+
 def fp8_quantize(
-    weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
+    weight, scale=None, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
 ):
     if FBGEMM_DYN_AVAILABLE and not scalar:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
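For intuition, here is a small standalone sketch (not part of the diff; the tensor values are made up) of why reinterpreting e4m3fn bits as e4m3fnuz requires doubling the scale:

```python
import torch

# Toy tensor quantized to e4m3fn, with an illustrative dequantization scale.
weight = torch.tensor([0.875, -1.5, 0.0]).to(torch.float8_e4m3fn)
weight_scale = torch.tensor(0.5)

# Reinterpret the same bits as e4m3fnuz. Because e4m3fnuz uses an exponent
# bias of 8 instead of 7, every (non-special) bit pattern now decodes to half
# of its e4m3fn value; the pattern 0b10000000 (-0.0 in e4m3fn) would decode to
# NaN, which is why normalize_e4m3fn_to_e4m3fnuz remaps it to 0 first.
weight_fnuz = weight.view(torch.float8_e4m3fnuz)

# Doubling the scale compensates, so the dequantized values are unchanged:
print(weight.float() * weight_scale)             # tensor([ 0.4375, -0.7500,  0.0000])
print(weight_fnuz.float() * (weight_scale * 2))  # same values
```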
@@ -62,8 +86,11 @@ def fp8_quantize(

     # weight, scale = quant_weights(weight, torch.int8, False)
     finfo = torch.finfo(qdtype)
-    # Calculate the scale as dtype max divided by absmax
-    scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)

+    if scale is None:
+        # Calculate the scale as dtype max divided by absmax
+        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)

     # scale and clamp the tensor to bring it to
     # the representative range of float8 data type
     # (as default cast is unsaturated)
@@ -72,6 +99,10 @@ def fp8_quantize(
     # as both required as inputs to torch._scaled_mm
     qweight = qweight.to(qdtype)
     scale = scale.float().reciprocal()

+    if SYSTEM == "rocm":
+        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)

     return qweight, scale
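As a rough reference for the per-tensor path above, a standalone sketch (not the PR's code; `w` and the static scale below are made-up values) of what fp8_quantize computes, and how a caller-supplied `scale` now bypasses the absmax calculation:

```python
import torch

qdtype = torch.float8_e4m3fn
finfo = torch.finfo(qdtype)  # finfo.max == 448.0 for e4m3fn

w = torch.randn(16, 32)  # made-up weight

# Dynamic scaling: multiplier derived from the tensor's absmax.
scale = finfo.max / w.abs().max().clamp(min=1e-12)
qw = (w * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
dequant_scale = scale.float().reciprocal()  # what fp8_quantize returns

# Static scaling (the new `scale=` argument): a precomputed multiplier is
# passed in -- e.g. the reciprocal of a checkpoint's input_scale -- and the
# absmax computation above is skipped.
static_scale = torch.tensor(100.0)  # hypothetical value
qw_static = (w * static_scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)

# Dequantization approximately recovers the original tensor.
print((qw.float() * dequant_scale - w).abs().max())
```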
@@ -92,9 +123,17 @@ def get_weights(self, weights: "Weights", prefix: str):
                 .reshape(-1)
                 .expand(w.shape[0])
             )

+            input_scale = None
+            if weights.has_tensor(f"{prefix}.input_scale"):
+                input_scale = weights.get_tensor(
+                    f"{prefix}.input_scale", to_dtype=False
+                ).reshape(-1)

             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
@@ -125,9 +164,24 @@ def get_weights_col_packed(
             )
             scale = scale.reshape(-1).expand(w.shape[0])

+            input_scale = None
+            if weights.has_tensor(f"{prefix}.input_scale"):
+                input_scale = weights.get_tensor(
+                    f"{prefix}.input_scale", to_dtype=False
+                )
+                if input_scale.numel() > 1:
+                    input_scale = weights.get_packed_sharded(
+                        f"{prefix}.input_scale",
+                        dim=0,
+                        block_sizes=block_sizes,
+                        to_dtype=False,
+                    )
+                input_scale = input_scale.reshape(-1).max()

             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
@@ -154,9 +208,22 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)

+            input_scale = [
+                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
+                for p, shape in zip(prefixes, shapes)
+                if weights.has_tensor(f"{p}.input_scale")
Review comment: Given this conditional, we probably need an assertion like …
+            ]
+            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
+            input_scale = (
+                torch.cat(input_scale, dim=0).reshape(-1).max()
+                if len(input_scale) != 0
+                else None
+            )

             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
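As a small aside, a standalone sketch (invented scale values) of the reduction used above when several shards each provide their own input_scale:

```python
import torch

# Hypothetical per-shard activation scales, e.g. from fused q/k/v prefixes.
input_scale = [torch.tensor([0.021]), torch.tensor([0.018]), torch.tensor([0.025])]

# The fused projection keeps a single per-tensor scale; taking the max is the
# conservative choice, since the largest scale covers the activation range of
# every shard without clipping.
fused = torch.cat(input_scale, dim=0).reshape(-1).max()
print(fused)  # tensor(0.0250)
```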
@@ -174,9 +241,16 @@ def get_weights_row(self, weights: "Weights", prefix: str):
                 .reshape(-1)
                 .expand(w.shape[0])
             )
+            input_scale = None
+            if weights.has_tensor(f"{prefix}.input_scale"):
+                input_scale = weights.get_tensor(
+                    f"{prefix}.input_scale", to_dtype=False
+                ).reshape(-1)

             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
@@ -191,6 +265,7 @@ class Fp8Weight(Weight):
     weight: torch.Tensor
     dtype: torch.dtype
     weight_scale: Optional[torch.Tensor] = None
+    input_scale: Optional[torch.Tensor] = None
     activation_scale_ub: Optional[float] = None

     def get_linear(self, bias: torch.Tensor):
@@ -200,56 +275,99 @@ def get_linear(self, bias: torch.Tensor):
         # memory. Can be non-contiguous when we e.g. expand from scalars.
         self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
-            self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
+            weight=self.weight,
+            scale=self.weight_scale,
+            dtype=self.dtype,
+            bias=bias,
+            input_scale=self.input_scale,
+            scale_upper_bound=self.activation_scale_ub,
         )


 class Fp8Linear(torch.nn.Module):
Review comment: Would it be cleaner to have a separate Fp8LinearRocm?
Reply: Maybe, it depends a bit on how much conditional code we end up with. We did separate FP8 Marlin for this reason.
+    _device_identity_cache = {}

     def __init__(
         self,
-        qweight,
-        scale,
-        scale_upper_bound,
-        bias,
-        dtype,
+        qweight: torch.Tensor,
+        scale: torch.Tensor,
+        dtype: torch.dtype,
+        bias: Optional[torch.Tensor] = None,
+        input_scale: Optional[torch.Tensor] = None,
+        scale_upper_bound: Optional[float] = None,
     ) -> None:
         super().__init__()
         if FBGEMM_MM_AVAILABLE:
             log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
+        if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
+            qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                weight=qweight, weight_scale=scale
+            )

         self.dtype = dtype
         self.qweight = qweight
-        self.scale = scale
-        self.scale_upper_bound = (
-            torch.tensor(
-                [scale_upper_bound], dtype=torch.float32, device=qweight.device
-            )
-            if scale_upper_bound is not None
-            else None
+        self.scale = scale.float()
+        self.input_scale = (
+            input_scale.float().reciprocal() if input_scale is not None else None
         )

+        if FBGEMM_MM_AVAILABLE:
+            self.scale_upper_bound = (
+                torch.tensor(
+                    [scale_upper_bound], dtype=torch.float32, device=qweight.device
+                )
+                if scale_upper_bound is not None
+                else None
+            )
+        else:
+            self.scale_upper_bound = scale_upper_bound

         self.bias = bias if bias is not None else None

     @classmethod
     def from_unquant(cls, weight, bias, dtype):
         qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
         return cls(
-            qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
+            qweight=qweight,
+            scale=scale,
+            dtype=dtype,
+            bias=bias,
+            input_scale=None,
+            scale_upper_bound=None,
         )

     @classmethod
-    def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+    def from_fp8(
+        cls,
+        weight: torch.Tensor,
+        scale: torch.Tensor,
+        dtype: torch.dtype,
+        bias: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> "Fp8Linear":
+        input_scale = kwargs.get("input_scale", None)
+        scale_upper_bound = kwargs.get("scale_upper_bound", None)

         if FBGEMM_DYN_AVAILABLE:
             # fbgemm needs float32 scales.
             scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
-            scale_upper_bound=input_scale,
+            input_scale=input_scale,
+            scale_upper_bound=scale_upper_bound,
             bias=bias,
             dtype=dtype,
         )

+    @classmethod
+    def get_shared_device_identity(cls, device):
+        # Input scaling factors are no longer optional in _scaled_mm starting
+        # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
+        if device not in cls._device_identity_cache:
+            cls._device_identity_cache[device] = torch.ones(1, device=device)
+        return cls._device_identity_cache[device]

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if FBGEMM_MM_AVAILABLE:
             qinput, scale = fp8_quantize(
@@ -266,15 +384,49 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             )
             return y.to(self.dtype)

-        qinput, scale = fp8_quantize(input, scalar=True)
-        output, _ = torch._scaled_mm(
-            qinput,
-            self.qweight.t(),
-            out_dtype=self.dtype,
-            scale_a=scale,
-            scale_b=self.scale,
-            bias=self.bias,
+        qinput, scale = fp8_quantize(
+            input,
+            self.input_scale,
+            scale_upper_bound=self.scale_upper_bound,
+            scalar=True,
         )

+        per_tensor_weights = self.scale.numel() == 1
+        per_tensor_activations = scale.numel() == 1

+        if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
+            output = torch._scaled_mm(
+                qinput,
+                self.qweight.t(),
+                out_dtype=self.dtype,
+                scale_a=scale,
+                scale_b=self.scale,
+                bias=self.bias,
+            )

+            if isinstance(output, tuple) and len(output) == 2:
+                output = output[0]
+        else:
+            device_identity = None
+            if SYSTEM == "rocm":
+                device_identity = self.get_shared_device_identity(self.qweight.device)

+            output = torch._scaled_mm(
+                qinput,
+                self.qweight.t(),
+                scale_a=device_identity,
+                scale_b=device_identity,
+                out_dtype=torch.float32,
+            )
+            if isinstance(output, tuple) and len(output) == 2:
+                output = output[0]

+            output = output * scale * self.scale.t()
+            if self.bias is not None:
+                output = output + self.bias

+            output = output.to(dtype=self.dtype)

         return output
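For reference, a standalone sketch of the arithmetic the fallback branch relies on (made-up integer codes and scales; a plain matmul stands in for torch._scaled_mm, which on ROCm is called with the shared identity scale and float32 output):

```python
import torch

# Stand-ins for FP8 codes and their per-row dequantization scales.
qinput = torch.randint(-4, 5, (2, 8)).float()     # activation codes
qweight = torch.randint(-4, 5, (3, 8)).float()    # weight codes
scale_a = torch.tensor([[0.02], [0.03]])          # activation scales, shape (2, 1)
scale_b = torch.tensor([[0.10], [0.20], [0.30]])  # weight scales, shape (3, 1)

# What a GEMM with rowwise scales would compute directly.
direct = (qinput * scale_a) @ (qweight * scale_b).t()

# The fallback: run the GEMM with identity scales, then rescale the result,
# which mirrors `output * scale * self.scale.t()` in the diff above.
identity = torch.ones(1)
raw = (qinput * identity) @ (qweight * identity).t()
rescaled = raw * scale_a * scale_b.t()

print(torch.allclose(direct, rescaled))  # True
```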
Review comment: We should wire up `scale` at some point for CUDA as well.