From ccaf9ff5077934936bb8ef0a42449bca2c0ee1dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 23 Sep 2024 15:46:41 +0000 Subject: [PATCH] Add support for scalar FP8 weight scales --- server/text_generation_server/layers/fp8.py | 46 ++++++++++++++++----- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index 59b08b55bd2..2933aea2c96 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -87,9 +87,11 @@ def get_weights(self, weights: "Weights", prefix: str): if w.dtype == torch.float8_e4m3fn: # FP8 branch - scale = weights.get_tensor( - f"{prefix}.weight_scale", to_dtype=False - ).reshape(-1) + scale = ( + weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + .reshape(-1) + .expand(w.shape[0]) + ) return Fp8Weight( weight=w, weight_scale=scale, @@ -113,9 +115,16 @@ def get_weights_col_packed( if w.dtype == torch.float8_e4m3fn: # FP8 branch - scale = weights.get_packed_sharded( - f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False - ).reshape(-1) + scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + if scale.numel() > 1: + scale = weights.get_packed_sharded( + f"{prefix}.weight_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + scale = scale.reshape(-1).expand(w.shape[0]) + return Fp8Weight( weight=w, weight_scale=scale, @@ -132,16 +141,19 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: in w = [ weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes ] + shapes = [x.shape for x in w] + # Concat then send to the device w = torch.cat(w, dim=dim).to(weights.device) # FP8 branch if w.dtype == torch.float8_e4m3fn: scale = [ - weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False) - for p in prefixes + _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) + for p, shape in zip(prefixes, shapes) ] scale = torch.cat(scale, dim=0).reshape(-1) + return Fp8Weight( weight=w, weight_scale=scale, @@ -157,9 +169,11 @@ def get_weights_row(self, weights: "Weights", prefix: str): w = weights.get_sharded(f"{prefix}.weight", dim=1) # FP8 branch if w.dtype == torch.float8_e4m3fn: - scale = weights.get_tensor( - f"{prefix}.weight_scale", to_dtype=False - ).reshape(-1) + scale = ( + weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + .reshape(-1) + .expand(w.shape[0]) + ) return Fp8Weight( weight=w, weight_scale=scale, @@ -182,6 +196,9 @@ class Fp8Weight(Weight): def get_linear(self, bias: torch.Tensor): if self.weight_scale is None: return get_fp8_linear().from_unquant(self.weight, bias, self.dtype) + # This is not checked by the fbgemm kernels, but they require contiguous + # memory. Can be non-contiguous when we e.g. expand from scalars. + self.weight_scale = self.weight_scale.contiguous() return get_fp8_linear().from_fp8( self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype ) @@ -256,3 +273,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: bias=self.bias, ) return output + + +def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size): + scale = weights.get_tensor(prefix, to_dtype=False) + if scale.numel() > 1: + scale = weights.get_sharded(prefix, dim=0, to_dtype=False) + return scale.reshape(-1).expand(shape[0])