fix(l4): fix fp8 logic on l4 (#2277)
* fix(l4): fix fp8 logic on l4

* also quant weights with single scale

* use marlin even on 89
OlivierDehaene authored Jul 23, 2024
1 parent abc3253 commit 5fca30e
Showing 1 changed file with 11 additions and 6 deletions.
server/text_generation_server/layers/fp8.py (17 changes: 11 additions & 6 deletions)
@@ -32,8 +32,8 @@ def get_fp8_linear() -> torch.nn.Module:
     """

     if SYSTEM == "cuda":
-        major, minor = torch.cuda.get_device_capability()
-        if major == 8 and minor < 9:
+        major, _ = torch.cuda.get_device_capability()
+        if major == 8:
             from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

             return GPTQMarlinFP8Linear
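
For context: the L4 is an Ada GPU that reports compute capability (8, 9), which the old `major == 8 and minor < 9` gate excluded from the Marlin path ("use marlin even on 89"). A minimal illustrative sketch of how the new gate classifies a few devices; the device list here is an assumption for illustration, not part of the diff:

# Illustration only: classify devices by compute capability under the new check.
# On a real machine the pair comes from torch.cuda.get_device_capability().
def picks_marlin(major: int, minor: int) -> bool:
    # New logic from the hunk above: any SM 8.x device uses GPTQMarlinFP8Linear;
    # the minor version is intentionally ignored, matching the `major, _` change.
    return major == 8

for name, cap in [("A100", (8, 0)), ("A10G", (8, 6)), ("L4", (8, 9)), ("H100", (9, 0))]:
    print(name, cap, "Marlin" if picks_marlin(*cap) else "Fp8Linear")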
@@ -42,8 +42,10 @@ def get_fp8_linear() -> torch.nn.Module:
     return Fp8Linear


-def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
-    if FBGEMM_DYN_AVAILABLE:
+def fp8_quantize(
+    weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
+):
+    if FBGEMM_DYN_AVAILABLE and not scalar:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
             weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
         )
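
The new `scalar` flag skips FBGEMM's per-row quantization and produces a single per-tensor scale instead. The fallback branch of `fp8_quantize` sits below the fold in this diff; the sketch below is a hedged approximation of what per-tensor FP8 quantization looks like, not a copy of the file:

import torch

def fp8_quantize_scalar_sketch(weight: torch.Tensor, qdtype=torch.float8_e4m3fn):
    """Hypothetical per-tensor FP8 quantization: one scale for the whole tensor."""
    finfo = torch.finfo(qdtype)
    # Scale so the largest magnitude maps to the FP8 dtype's max value.
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    # Saturate before casting, since a plain cast would not clamp out-of-range values.
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Return the inverse scale, which is what scaled matmuls consume to dequantize.
    return qweight, scale.float().reciprocal()

# Example: quantize a bf16 weight with a single scale.
w = torch.randn(16, 32, dtype=torch.bfloat16)
qw, inv_scale = fp8_quantize_scalar_sketch(w)
print(qw.dtype, inv_scale.item())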
@@ -186,6 +188,9 @@ def __init__(
         dtype,
     ) -> None:
         super().__init__()
+        if FBGEMM_MM_AVAILABLE:
+            log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
+
         self.dtype = dtype
         self.qweight = qweight
         self.scale = scale
@@ -201,7 +206,7 @@ def __init__(

     @classmethod
     def from_unquant(cls, weight, bias, dtype):
-        qweight, scale = fp8_quantize(weight)
+        qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
         return cls(
             qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
         )
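
With this change the weights also degrade gracefully: when the FBGEMM matmul kernels are unavailable (as on an L4), `from_unquant` requests a single per-tensor scale so the weight scale matches what the `torch._scaled_mm` path can consume ("also quant weights with single scale"). A usage sketch, assuming text-generation-inference is installed and a CUDA device is present; the tensor shapes are placeholders:

import torch
from text_generation_server.layers.fp8 import Fp8Linear

# Placeholder weight/bias; in the server these come from the checkpoint loader.
weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
bias = None

# After this commit, from_unquant quantizes with a single scale whenever
# FBGEMM_MM_AVAILABLE is False, e.g. on compute capability 8.9.
# Calling the resulting layer requires an FP8-capable GPU (SM 8.9+).
fp8_linear = Fp8Linear.from_unquant(weight, bias, dtype=torch.bfloat16)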
@@ -232,7 +237,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             )
             return y.to(self.dtype)

-        qinput, scale = fp8_quantize(input)
+        qinput, scale = fp8_quantize(input, scalar=True)
         output, _ = torch._scaled_mm(
             qinput,
             self.qweight.t(),
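
In the non-FBGEMM forward path the activation is now also quantized with a single scale (`scalar=True`), so both operands hand `torch._scaled_mm` scalar scale tensors. As a reference for the numerics, here is a hedged, CPU-only sketch of what a per-tensor scaled matmul amounts to; it is not the repository's code:

import torch

# Reference sketch (assumption, for illustration): the math behind a per-tensor
# scaled FP8 matmul, written in plain float32 so it runs on CPU. qinput/qweight
# stand in for the FP8 tensors above; x_inv_scale/w_inv_scale are the scalar
# inverse scales returned by fp8_quantize.
def scaled_mm_reference(qinput, x_inv_scale, qweight, w_inv_scale, bias=None):
    y = (qinput.float() * x_inv_scale) @ (qweight.float() * w_inv_scale).t()
    if bias is not None:
        y = y + bias.float()
    return y

x = torch.randn(2, 8)
w = torch.randn(4, 8)
# Pretend these came from fp8_quantize(..., scalar=True):
finfo = torch.finfo(torch.float8_e4m3fn)
x_scale = finfo.max / x.abs().max()
w_scale = finfo.max / w.abs().max()
qx = (x * x_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
qw = (w * w_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
# Approximately equal to x @ w.t():
print(scaled_mm_reference(qx, 1.0 / x_scale, qw, 1.0 / w_scale))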
