diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index bf5a098969e..54550ee3bf5 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -32,8 +32,8 @@ def get_fp8_linear() -> torch.nn.Module:
     """
 
     if SYSTEM == "cuda":
-        major, minor = torch.cuda.get_device_capability()
-        if major == 8 and minor < 9:
+        major, _ = torch.cuda.get_device_capability()
+        if major == 8:
             from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
 
             return GPTQMarlinFP8Linear
@@ -42,8 +42,10 @@ def get_fp8_linear() -> torch.nn.Module:
     return Fp8Linear
 
 
-def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
-    if FBGEMM_DYN_AVAILABLE:
+def fp8_quantize(
+    weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
+):
+    if FBGEMM_DYN_AVAILABLE and not scalar:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
             weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
         )
@@ -186,6 +188,9 @@ def __init__(
         dtype,
     ) -> None:
         super().__init__()
+        if FBGEMM_MM_AVAILABLE:
+            log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
+
         self.dtype = dtype
         self.qweight = qweight
         self.scale = scale
@@ -201,7 +206,7 @@ def __init__(
 
     @classmethod
     def from_unquant(cls, weight, bias, dtype):
-        qweight, scale = fp8_quantize(weight)
+        qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
         return cls(
             qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
         )
@@ -232,7 +237,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             )
             return y.to(self.dtype)
 
-        qinput, scale = fp8_quantize(input)
+        qinput, scale = fp8_quantize(input, scalar=True)
         output, _ = torch._scaled_mm(
             qinput,
             self.qweight.t(),
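
For context, the new `scalar` flag makes `fp8_quantize` skip FBGEMM's per-row kernel and fall back to per-tensor quantization, which is the form the `torch._scaled_mm` path in `forward` expects. A minimal sketch of what that scalar path computes, assuming a PyTorch build with float8 dtypes (the helper name `per_tensor_fp8_quantize` is ours, not part of the patch):

```python
import torch


def per_tensor_fp8_quantize(weight: torch.Tensor, qdtype=torch.float8_e4m3fn):
    """Sketch of scalar (per-tensor) fp8 quantization: one scale for the
    whole tensor, derived from its absolute maximum, instead of one per row."""
    finfo = torch.finfo(qdtype)
    # Scale so the largest magnitude maps onto the fp8 dtype's max value.
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    # Saturate before casting, since the default cast does not clamp.
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Return the inverse scale, which is what the matmul consumes to dequantize.
    return qweight, scale.float().reciprocal()
```

With `scalar=True` on both the weight (via `from_unquant` when FBGEMM matmul is unavailable) and the activation (in `forward`), each operand contributes a single scale to the scaled matmul, rather than a per-row scale vector.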