diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 59b08b55bd2..9a6282849dc 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -87,9 +87,11 @@ def get_weights(self, weights: "Weights", prefix: str):
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if scale.dim() == 0:
+                scale = scale.expand([w.shape[0]])
+            else:
+                scale = scale.reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -113,9 +115,17 @@ def get_weights_col_packed(
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_packed_sharded(
-                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            ).reshape(-1)
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if scale.dim() == 0:
+                scale = scale.expand(w.shape[0])
+            else:
+                scale = weights.get_packed_sharded(
+                    f"{prefix}.weight_scale",
+                    dim=0,
+                    block_sizes=block_sizes,
+                    to_dtype=False,
+                ).reshape(-1)
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -132,16 +142,19 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: in
         w = [
             weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
         ]
+        shapes = [x.shape for x in w]
+
+        # Concat then send to the device
         w = torch.cat(w, dim=dim).to(weights.device)
 
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             scale = [
-                weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
-                for p in prefixes
+                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+                for p, shape in zip(prefixes, shapes)
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -157,9 +170,11 @@ def get_weights_row(self, weights: "Weights", prefix: str):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if scale.dim() == 0:
+                scale = scale.expand(w.shape[0])
+            else:
+                scale = scale.reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -182,6 +197,9 @@ class Fp8Weight(Weight):
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
             return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
+        # This is not checked by the fbgemm kernels, but they require contiguous
+        # memory. Can be non-contiguous when we e.g. expand from scalars.
+        self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
             self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
         )
@@ -256,3 +274,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             bias=self.bias,
         )
         return output
+
+
+def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
+    scale = weights.get_tensor(prefix, to_dtype=False)
+    if scale.dim():
+        return weights.get_sharded(prefix, dim=0, to_dtype=False)
+    else:
+        return scale.expand(shape[0])
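
The pattern this diff applies in every loader path is the same: accept either a 0-dim scalar (per-tensor) FP8 weight scale or a per-output-channel scale, and normalize both to one scale entry per output row. A minimal standalone sketch of that dichotomy, with a hypothetical helper name that is not part of the patched file:

```python
import torch


def normalize_weight_scale(scale: torch.Tensor, out_features: int) -> torch.Tensor:
    """Hypothetical sketch of the scalar-vs-matrix scale handling above."""
    if scale.dim() == 0:
        # Per-tensor quantization: broadcast the scalar to one entry per
        # output row. expand() returns a non-contiguous (stride-0) view,
        # which is why get_linear() in the diff calls .contiguous() before
        # passing the scale to the fbgemm kernels.
        return scale.expand(out_features)
    # Per-channel quantization: checkpoints may store the scales as an
    # (out_features, 1) matrix; flatten to a 1-D vector.
    return scale.reshape(-1)


# A scalar scale is expanded to shape (4,) ...
print(normalize_weight_scale(torch.tensor(0.5), 4))
# ... and a (4, 1) per-row scale matrix is flattened to shape (4,).
print(normalize_weight_scale(torch.full((4, 1), 0.25), 4))
```

For sharded multi-weight loads, `_load_scalar_or_matrix_scale` applies the same split per prefix: a scalar is simply expanded to the corresponding weight shard's row count, while a matrix scale must itself be loaded sharded along dim 0 so its rows line up with the weight shard.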