Fp8 e4m3_fnuz support for rocm (#2588)
* (feat) fp8 fnuz support for rocm

* (review comments) Fix compression_config load, type hints

* (bug) update all has_tensor

* (review_comments) fix typo and added comments

* (nit) improved comment
mht-sharma authored Oct 16, 2024
1 parent ffe05cc commit 704a58c
Showing 6 changed files with 211 additions and 37 deletions.
208 changes: 180 additions & 28 deletions server/text_generation_server/layers/fp8.py
@@ -1,7 +1,7 @@
import torch

from dataclasses import dataclass
from typing import Optional, Union, List
from typing import Optional, Tuple, Union, List
from loguru import logger

from text_generation_server.utils.import_utils import SYSTEM
@@ -51,8 +51,32 @@ def get_fp8_linear() -> torch.nn.Module:
return Fp8Linear


def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
assert weight.dtype == torch.float8_e4m3fn
# The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html
weight_as_int8 = weight.view(torch.int8)
ROCM_FP8_NAN_AS_INT = -128
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
weight = weight_as_int8.view(torch.float8_e4m3fnuz)

# For the same bits representation, e4m3fnuz value is half of
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
weight_scale = weight_scale * 2.0
if input_scale is not None:
input_scale = input_scale * 2.0
return weight, weight_scale, input_scale
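
A minimal usage sketch of the helper above, assuming a PyTorch build that exposes both float8 dtypes (illustrative values, not part of the diff):

import torch
from text_generation_server.layers.fp8 import normalize_e4m3fn_to_e4m3fnuz

# Hypothetical e4m3fn weight with a per-tensor scale.
w_fn = torch.tensor([0.5, -1.0, 2.0], dtype=torch.float16).to(torch.float8_e4m3fn)
s_fn = torch.tensor([0.125], dtype=torch.float32)

w_fnuz, s_fnuz, _ = normalize_e4m3fn_to_e4m3fnuz(w_fn, s_fn)

# The same bit pattern decodes to half the value in e4m3fnuz, and the doubled
# scale keeps the dequantized result identical.
assert torch.allclose(w_fn.to(torch.float32) * s_fn, w_fnuz.to(torch.float32) * s_fnuz)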


def fp8_quantize(
weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
weight, scale=None, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
):
if FBGEMM_DYN_AVAILABLE and not scalar:
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
@@ -62,8 +86,11 @@ def fp8_quantize(

# weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype)
# Calculate the scale as dtype max divided by absmax
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)

if scale is None:
# Calculate the scale as dtype max divided by absmax
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)

# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
@@ -72,6 +99,10 @@ def fp8_quantize(
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(qdtype)
scale = scale.float().reciprocal()

if SYSTEM == "rocm":
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)

return qweight, scale
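
The new optional scale argument lets callers reuse a precomputed quantization multiplier (Fp8Linear.forward passes the reciprocal of the checkpoint's input_scale) instead of recomputing one from the tensor's absmax. A rough sketch of both call patterns, with hypothetical tensors and a CUDA/ROCm device assumed:

import torch
from text_generation_server.layers.fp8 import fp8_quantize

x = torch.randn(4, 16, dtype=torch.float16, device="cuda")

# Dynamic per-tensor quantization: the scale is derived from x.abs().max().
qx, dequant_scale = fp8_quantize(x, scalar=True)

# Static quantization: pass the quantization multiplier directly, here the
# reciprocal of a hypothetical calibrated input_scale from a checkpoint.
input_scale = torch.tensor(0.05, device="cuda")
qx_static, dequant_scale_static = fp8_quantize(x, scale=input_scale.reciprocal(), scalar=True)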


@@ -92,9 +123,17 @@ def get_weights(self, weights: "Weights", prefix: str):
.reshape(-1)
.expand(w.shape[0])
)

input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)

return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@@ -125,9 +164,24 @@ def get_weights_col_packed(
)
scale = scale.reshape(-1).expand(w.shape[0])

input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
)
if input_scale.numel() > 1:
input_scale = weights.get_packed_sharded(
f"{prefix}.input_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = input_scale.reshape(-1).max()

return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@@ -154,9 +208,22 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: in
]
scale = torch.cat(scale, dim=0).reshape(-1)

input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)

return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
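
The per-prefix input scales are collapsed with .max() above because fused column-parallel weights (e.g. q/k/v projections) share a single matmul over the same activations, so one conservative static scale is kept for the fused projection. A tiny sketch of that reduction with hypothetical values:

import torch

# Hypothetical calibrated input scales for three fused prefixes.
input_scales = [torch.tensor([0.02]), torch.tensor([0.05]), torch.tensor([0.03])]

# Keep the largest dequantization scale so the fused projection's static
# activation range covers all three original projections.
fused_input_scale = torch.cat(input_scales, dim=0).reshape(-1).max()  # tensor(0.0500)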
@@ -174,9 +241,16 @@ def get_weights_row(self, weights: "Weights", prefix: str):
.reshape(-1)
.expand(w.shape[0])
)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)

return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@@ -191,6 +265,7 @@ class Fp8Weight(Weight):
weight: torch.Tensor
dtype: torch.dtype
weight_scale: Optional[torch.Tensor] = None
input_scale: Optional[torch.Tensor] = None
activation_scale_ub: Optional[float] = None

def get_linear(self, bias: torch.Tensor):
@@ -200,56 +275,99 @@ def get_linear(self, bias: torch.Tensor):
# memory. Can be non-contiguous when we e.g. expand from scalars.
self.weight_scale = self.weight_scale.contiguous()
return get_fp8_linear().from_fp8(
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
weight=self.weight,
scale=self.weight_scale,
dtype=self.dtype,
bias=bias,
input_scale=self.input_scale,
scale_upper_bound=self.activation_scale_ub,
)


class Fp8Linear(torch.nn.Module):
_device_identity_cache = {}

def __init__(
self,
qweight,
scale,
scale_upper_bound,
bias,
dtype,
qweight: torch.Tensor,
scale: torch.Tensor,
dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[float] = None,
) -> None:
super().__init__()
if FBGEMM_MM_AVAILABLE:
log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=qweight, weight_scale=scale
)

self.dtype = dtype
self.qweight = qweight
self.scale = scale
self.scale_upper_bound = (
torch.tensor(
[scale_upper_bound], dtype=torch.float32, device=qweight.device
)
if scale_upper_bound is not None
else None
self.scale = scale.float()
self.input_scale = (
input_scale.float().reciprocal() if input_scale is not None else None
)

if FBGEMM_MM_AVAILABLE:
self.scale_upper_bound = (
torch.tensor(
[scale_upper_bound], dtype=torch.float32, device=qweight.device
)
if scale_upper_bound is not None
else None
)
else:
self.scale_upper_bound = scale_upper_bound

self.bias = bias if bias is not None else None

@classmethod
def from_unquant(cls, weight, bias, dtype):
qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
return cls(
qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
qweight=qweight,
scale=scale,
dtype=dtype,
bias=bias,
input_scale=None,
scale_upper_bound=None,
)

@classmethod
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
def from_fp8(
cls,
weight: torch.Tensor,
scale: torch.Tensor,
dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
**kwargs,
) -> "Fp8Linear":
input_scale = kwargs.get("input_scale", None)
scale_upper_bound = kwargs.get("scale_upper_bound", None)

if FBGEMM_DYN_AVAILABLE:
# fbgemm needs float32 scales.
scale = scale.float()
return cls(
qweight=weight,
scale=scale,
scale_upper_bound=input_scale,
input_scale=input_scale,
scale_upper_bound=scale_upper_bound,
bias=bias,
dtype=dtype,
)

@classmethod
def get_shared_device_identity(cls, device):
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
if device not in cls._device_identity_cache:
cls._device_identity_cache[device] = torch.ones(1, device=device)
return cls._device_identity_cache[device]

def forward(self, input: torch.Tensor) -> torch.Tensor:
if FBGEMM_MM_AVAILABLE:
qinput, scale = fp8_quantize(
@@ -266,15 +384,49 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
)
return y.to(self.dtype)

qinput, scale = fp8_quantize(input, scalar=True)
output, _ = torch._scaled_mm(
qinput,
self.qweight.t(),
out_dtype=self.dtype,
scale_a=scale,
scale_b=self.scale,
bias=self.bias,
qinput, scale = fp8_quantize(
input,
self.input_scale,
scale_upper_bound=self.scale_upper_bound,
scalar=True,
)

per_tensor_weights = self.scale.numel() == 1
per_tensor_activations = scale.numel() == 1

if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
output = torch._scaled_mm(
qinput,
self.qweight.t(),
out_dtype=self.dtype,
scale_a=scale,
scale_b=self.scale,
bias=self.bias,
)

if isinstance(output, tuple) and len(output) == 2:
output = output[0]
else:
device_identity = None
if SYSTEM == "rocm":
device_identity = self.get_shared_device_identity(self.qweight.device)

output = torch._scaled_mm(
qinput,
self.qweight.t(),
scale_a=device_identity,
scale_b=device_identity,
out_dtype=torch.float32,
)
if isinstance(output, tuple) and len(output) == 2:
output = output[0]

output = output * scale * self.scale.t()
if self.bias is not None:
output = output + self.bias

output = output.to(dtype=self.dtype)

return output
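
In the non-per-tensor ROCm branch above, _scaled_mm runs with dummy identity scales in float32 and the real activation and per-row weight scales are applied afterwards. A plain-float32 sketch of that rescaling, with illustrative shapes and no FP8 kernels:

import torch

qinput = torch.randn(4, 8)       # stand-in for the quantized activations
qweight = torch.randn(16, 8)     # stand-in for the FP8 weight, one row per output feature
scale_a = torch.tensor(0.5)      # per-tensor activation scale
scale_b = torch.rand(16, 1)      # per-row (per-output-channel) weight scale

# _scaled_mm with identity scales reduces to a plain matmul ...
unscaled = qinput @ qweight.t()
# ... and the real scales are applied afterwards; the row scales broadcast over
# the output columns, hence `self.scale.t()` in the code above.
rescaled = unscaled * scale_a * scale_b.t()

# Equivalent to scaling the operands before the matmul.
reference = (qinput * scale_a) @ (qweight * scale_b).t()
torch.testing.assert_close(rescaled, reference)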


4 changes: 2 additions & 2 deletions server/text_generation_server/layers/gptq/__init__.py
@@ -392,15 +392,15 @@ def _get_gptq_params(self, weights: Weights):
)

def _get_gptq_params(self, weights: Weights):
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
# `server quantize` used asymmetric quantization unconditionally
# before the `gptq_sym` setting tensor was added.
self.sym = (
weights.get_tensor("gptq_sym").item()
if weights._has_tensor("gptq_sym")
if weights.has_tensor("gptq_sym")
else False
)
self.quant_method = "gptq"
9 changes: 8 additions & 1 deletion server/text_generation_server/layers/marlin/fp8.py
@@ -62,7 +62,14 @@ def from_unquant(cls, weight, bias, dtype):
return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)

@classmethod
def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
def from_fp8(
cls,
weight: torch.Tensor,
scale: torch.Tensor,
bias: torch.Tensor,
dtype: torch.dtype,
**kwargs,
):
return cls(qweight=weight, scales=scale.to(dtype), bias=bias)

def forward(self, A: torch.Tensor) -> torch.Tensor:
4 changes: 2 additions & 2 deletions server/text_generation_server/layers/marlin/gptq.py
@@ -231,15 +231,15 @@ def _get_gptq_params(self, weights: Weights):
)

def _get_gptq_params(self, weights: Weights):
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
# `server quantize` used asymmetric quantization unconditionally
# before the `gptq_sym` setting tensor was added.
self.sym = (
weights.get_tensor("gptq_sym").item()
if weights._has_tensor("gptq_sym")
if weights.has_tensor("gptq_sym")
else False
)
self.quant_method = "gptq"
