FP8 e4m3_fnuz support for ROCm #2588

Merged
merged 5 commits on Oct 16, 2024
Changes from 2 commits
207 changes: 179 additions & 28 deletions server/text_generation_server/layers/fp8.py
@@ -1,7 +1,7 @@
import torch

from dataclasses import dataclass
from typing import Optional, Union, List
from typing import Optional, Tuple, Union, List
from loguru import logger

from text_generation_server.utils.import_utils import SYSTEM
@@ -51,8 +51,32 @@ def get_fp8_linear() -> torch.nn.Module:
return Fp8Linear


def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
assert weight.dtype == torch.float8_e4m3fn
# The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html
weight_as_int8 = weight.view(torch.int8)
ROCM_FP8_NAN_AS_INT = -128
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
weight = weight_as_int8.view(torch.float8_e4m3fnuz)

# For the same bits representation, e4m3fnuz value is half of
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
weight_scale = weight_scale * 2.0
if input_scale is not None:
input_scale = input_scale * 2.0
return weight, weight_scale, input_scale

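As a sanity check on the reinterpretation above, here is a minimal sketch (assuming a PyTorch build that exposes both float8 dtypes and that the helper is importable from this module as added here): the same bit pattern decodes to half the value under e4m3fnuz, so doubling the scale leaves the dequantized weights unchanged, and the -0.0 pattern (0x80) is rewritten to +0.0 because it would otherwise decode as NaN.

```python
import torch

from text_generation_server.layers.fp8 import normalize_e4m3fn_to_e4m3fnuz

# Three representative e4m3fn values, including -0.0 (bit pattern 0x80,
# remapped by the helper so it does not become a fnuz NaN).
w_fn = torch.tensor([1.0, -0.0, 448.0]).to(torch.float8_e4m3fn)
scale = torch.tensor(0.25)

w_fnuz, scale_fnuz, _ = normalize_e4m3fn_to_e4m3fnuz(w_fn, scale)

# Same bits, half the value, double the scale: dequantized results match.
assert torch.equal(
    w_fn.to(torch.float32) * scale,
    w_fnuz.to(torch.float32) * scale_fnuz,
)
```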

def fp8_quantize(
weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
weight, scale=None, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
):
if FBGEMM_DYN_AVAILABLE and not scalar:
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
@@ -62,8 +86,11 @@ def fp8_quantize(

# weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype)
# Calculate the scale as dtype max divided by absmax
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)

if scale is None:
# Calculate the scale as dtype max divided by absmax
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)

# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
@@ -72,6 +99,10 @@ def fp8_quantize(
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(qdtype)
scale = scale.float().reciprocal()

if SYSTEM == "rocm":
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
Member: We should wire up scale at some point for CUDA as well.

return qweight, scale

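For reference, a standalone sketch of the dynamic per-tensor path above (no FBGEMM, no ROCm normalization, illustrative tensor shapes); it mirrors the scale/clamp/reciprocal steps so that dequantization is a single multiply:

```python
import torch

weight = torch.randn(16, 32, dtype=torch.float16)
finfo = torch.finfo(torch.float8_e4m3fn)

# The scale maps the tensor's absmax onto the fp8 dtype's max (448 for e4m3fn).
scale = finfo.max / weight.abs().max().clamp(min=1e-12)
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
# The stored scale is the reciprocal, so dequantization is qweight * scale.
scale = scale.float().reciprocal()

dequant = (qweight.to(torch.float32) * scale).to(weight.dtype)
assert torch.allclose(weight, dequant, atol=1e-1, rtol=1e-1)
```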

@@ -92,9 +123,17 @@ def get_weights(self, weights: "Weights", prefix: str):
.reshape(-1)
.expand(w.shape[0])
)

input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)

Member: Weights also has _has_tensor; maybe we should make it public and use it here?
Member: Same for the try: [...]get_tensor below.
Collaborator (author): Updated to use the has_tensor.

return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@@ -125,9 +164,24 @@ def get_weights_col_packed(
)
scale = scale.reshape(-1).expand(w.shape[0])

input_scale = None
if weights.get_tensor(f"{prefix}.input_scale"):
Member (suggested change): use weights.has_tensor(f"{prefix}.input_scale") instead of weights.get_tensor(f"{prefix}.input_scale")?
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
)
if input_scale.numel() > 1:
input_scale = weights.get_packed_sharded(
f"{prefix}.input_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = input_scale.reshape(-1).max()

return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
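When a packed checkpoint carries one input_scale per fused projection (e.g. Q/K/V), the code above collapses them to a single per-tensor scale by taking the maximum, which is the conservative choice. A small illustrative sketch (the values are made up):

```python
import torch

# Hypothetical per-projection input scales stored in a packed QKV checkpoint.
packed_input_scale = torch.tensor([0.25, 0.5, 0.75])
# Collapse to one per-tensor scale; the max keeps all three projections in range.
input_scale = packed_input_scale.reshape(-1).max()
assert input_scale.item() == 0.75
```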
@@ -154,9 +208,21 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: in
]
scale = torch.cat(scale, dim=0).reshape(-1)

input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
if weights.has_tensor(f"{p}.input_scale")
]
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)

Member: Given this conditional, we probably need an assertion like assert len(input_scale) == 0 or len(input_scale) == len(prefixes).

return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
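A hedged sketch of the guard the reviewer asks for above (a hypothetical helper, not part of this diff): input scales should be present for either none or all of the merged prefixes, otherwise the max() silently mixes calibrated and uncalibrated projections.

```python
from typing import List, Optional

import torch

def merge_input_scales(
    input_scales: List[torch.Tensor], num_prefixes: int
) -> Optional[torch.Tensor]:
    # Either no prefix ships an input scale, or all of them do.
    assert len(input_scales) in (0, num_prefixes), (
        f"input_scale found for {len(input_scales)} of {num_prefixes} prefixes"
    )
    if not input_scales:
        return None
    return torch.cat(input_scales, dim=0).reshape(-1).max()
```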
@@ -174,9 +240,16 @@ def get_weights_row(self, weights: "Weights", prefix: str):
.reshape(-1)
.expand(w.shape[0])
)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)

return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@@ -191,6 +264,7 @@ class Fp8Weight(Weight):
weight: torch.Tensor
dtype: torch.dtype
weight_scale: Optional[torch.Tensor] = None
input_scale: Optional[torch.Tensor] = None
activation_scale_ub: Optional[float] = None

def get_linear(self, bias: torch.Tensor):
@@ -200,56 +274,99 @@ def get_linear(self, bias: torch.Tensor):
# memory. Can be non-contiguous when we e.g. expand from scalars.
self.weight_scale = self.weight_scale.contiguous()
return get_fp8_linear().from_fp8(
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
weight=self.weight,
scale=self.weight_scale,
dtype=self.dtype,
bias=bias,
input_scale=self.input_scale,
scale_upper_bound=self.activation_scale_ub,
)


class Fp8Linear(torch.nn.Module):
Collaborator (author): Would it be cleaner to have a separate Fp8LinearRocm?
Member: Maybe, it depends a bit on how much conditional code we end up with. We did separate FP8 Marlin for this reason.

_device_identity_cache = {}

def __init__(
self,
qweight,
scale,
scale_upper_bound,
bias,
dtype,
qweight: torch.Tensor,
scale: torch.Tensor,
dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[float] = None,
) -> None:
super().__init__()
if FBGEMM_MM_AVAILABLE:
log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=qweight, weight_scale=scale
)

self.dtype = dtype
self.qweight = qweight
self.scale = scale
self.scale_upper_bound = (
torch.tensor(
[scale_upper_bound], dtype=torch.float32, device=qweight.device
)
if scale_upper_bound is not None
else None
self.scale = scale.float()
self.input_scale = (
input_scale.float().reciprocal() if input_scale is not None else None
)

if FBGEMM_MM_AVAILABLE:
self.scale_upper_bound = (
torch.tensor(
[scale_upper_bound], dtype=torch.float32, device=qweight.device
)
if scale_upper_bound is not None
else None
)
else:
self.scale_upper_bound = scale_upper_bound

self.bias = bias if bias is not None else None

@classmethod
def from_unquant(cls, weight, bias, dtype):
qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
return cls(
qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
qweight=qweight,
scale=scale,
dtype=dtype,
bias=bias,
input_scale=None,
scale_upper_bound=None,
)

@classmethod
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
def from_fp8(
cls,
weight: torch.Tensor,
scale: torch.Tensor,
dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
**kwargs,
) -> "Fp8Linear":
input_scale = kwargs.get("input_scale", None)
scale_upper_bound = kwargs.get("scale_upper_bound", None)

if FBGEMM_DYN_AVAILABLE:
# fbgemm needs float32 scales.
scale = scale.float()
return cls(
qweight=weight,
scale=scale,
scale_upper_bound=input_scale,
input_scale=input_scale,
scale_upper_bound=scale_upper_bound,
bias=bias,
dtype=dtype,
)

@classmethod
def get_shared_device_identity(cls, device):
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
if device not in cls._device_identity_cache:
cls._device_identity_cache[device] = torch.ones(1, device=device)
return cls._device_identity_cache[device]

def forward(self, input: torch.Tensor) -> torch.Tensor:
if FBGEMM_MM_AVAILABLE:
qinput, scale = fp8_quantize(
@@ -266,15 +383,49 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
)
return y.to(self.dtype)

qinput, scale = fp8_quantize(input, scalar=True)
output, _ = torch._scaled_mm(
qinput,
self.qweight.t(),
out_dtype=self.dtype,
scale_a=scale,
scale_b=self.scale,
bias=self.bias,
qinput, scale = fp8_quantize(
input,
self.input_scale,
scale_upper_bound=self.scale_upper_bound,
scalar=True,
)

per_tensor_weights = self.scale.numel() == 1
per_tensor_activations = scale.numel() == 1

if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
output = torch._scaled_mm(
qinput,
self.qweight.t(),
out_dtype=self.dtype,
scale_a=scale,
scale_b=self.scale,
bias=self.bias,
)

if isinstance(output, tuple) and len(output) == 2:
output = output[0]
else:
device_identity = None
if SYSTEM == "rocm":
device_identity = self.get_shared_device_identity(self.qweight.device)

output = torch._scaled_mm(
qinput,
self.qweight.t(),
scale_a=device_identity,
scale_b=device_identity,
out_dtype=torch.float32,
)
if isinstance(output, tuple) and len(output) == 2:
output = output[0]

output = output * scale * self.scale.t()
if self.bias is not None:
output = output + self.bias

output = output.to(dtype=self.dtype)

return output

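The ROCm fallback above relies on the scales factoring out of the matmul: with a per-row activation scale and a per-output-channel weight scale, running _scaled_mm with identity scales in float32 and rescaling afterwards is mathematically equivalent. A minimal sketch of that identity (plain float32 tensors, illustrative shapes):

```python
import torch

a = torch.randn(4, 8)          # activations
w = torch.randn(6, 8)          # weight, one row per output channel
s_a = torch.rand(4, 1) + 0.5   # per-row activation scales
s_w = torch.rand(6, 1) + 0.5   # per-row weight scales

# Scaling before the matmul ...
scaled_first = (a * s_a) @ (w * s_w).t()
# ... equals an unscaled matmul rescaled afterwards, as in the ROCm branch.
scaled_after = (a @ w.t()) * s_a * s_w.t()
assert torch.allclose(scaled_first, scaled_after, atol=1e-4)
```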

9 changes: 8 additions & 1 deletion server/text_generation_server/layers/marlin/fp8.py
@@ -62,7 +62,14 @@ def from_unquant(cls, weight, bias, dtype):
return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)

@classmethod
def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
def from_fp8(
cls,
weight: torch.Tensor,
scale: torch.Tensor,
bias: torch.Tensor,
dtype: torch.dtype,
**kwargs,
):
return cls(qweight=weight, scales=scale.to(dtype), bias=bias)

def forward(self, A: torch.Tensor) -> torch.Tensor:
17 changes: 8 additions & 9 deletions server/text_generation_server/models/__init__.py
@@ -342,22 +342,19 @@ def get_model(
model_type = config_dict.get("model_type", None)

quantization_config = config_dict.get("quantization_config", None)
compression_config = config_dict.get("compression_config", None)
Collaborator (author): @danieldk config renamed to quantisation config.
Member: I think we should check for both keys, at least for the time being. Some customers/users may have checkpoints that still have compression_config. Maybe with a comment that compression_config is for backwards compatibility?
if quantization_config is not None and quantize is None:
method = quantization_config.get("quant_method", None)
config_groups = quantization_config.get("config_groups", None)
if method in {"gptq", "awq", "exl2"}:
log_master(logger.info, f"Auto selecting quantization method {method}")
quantize = method
elif method == "fbgemm_fp8":
elif method == "fbgemm_fp8" or method == "fp8":
log_master(logger.info, "Auto selecting quantization method fp8")
quantize = "fp8"
else:
log_master(logger.warning, f"Unknown quantization method {method}")
elif compression_config is not None:
# TODO: at some point we should probably fully parse the compression
# configuration to know which parameters are compressed.
config_groups = compression_config.get("config_groups")
if config_groups is not None:
elif config_groups is not None:
# Compression config renamed to quantization_config
# TODO: at some point we should probably fully parse the compression
# configuration to know which parameters are compressed.
for _, group in config_groups.items():
weights_config = group.get("weights")
if weights_config is not None:
@@ -370,6 +367,8 @@
)
quantize = "fp8"
break
else:
log_master(logger.warning, f"Unknown quantization method {method}")

if dtype is None:
if quantize in ["awq", "exl2", "gptq", "marlin"]:
Expand Down
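A hedged sketch of the backwards-compatible lookup the reviewer suggests above (a hypothetical helper, not part of this diff): accept both keys while older checkpoints still use the legacy compression_config name.

```python
from typing import Any, Dict, Optional

def get_quantization_config(config_dict: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is None:
        # `compression_config` is the legacy key, kept for backwards compatibility.
        quantization_config = config_dict.get("compression_config", None)
    return quantization_config
```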
2 changes: 1 addition & 1 deletion server/text_generation_server/utils/weights.py
@@ -197,7 +197,7 @@ def _get_slice(self, tensor_name: str):
slice_ = f.get_slice(tensor_name)
return slice_

def _has_tensor(self, tensor_name: str):
def has_tensor(self, tensor_name: str):
try:
self.get_filename(tensor_name)
except Exception: