Add support for scalar FP8 weight scales (huggingface#2550)
* Add support for scalar FP8 weight scales

* Support LLM compressor FP8 checkpoints on H100

On H100 we use fbgemm-gpu, which requires bfloat16 as the input dtype.
However, FP8 quantization was not picked up for models quantized with
LLM Compressor. This change adds enough parsing to detect whether a
model has FP8-quantized weights (see the sketch after this message).

* Remove stray debug print
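
A minimal sketch (not part of this commit) of the idea described above: checkpoints produced by tools such as LLM Compressor may store a single per-tensor weight_scale rather than one scale per output row, and the loader now broadcasts such a scalar to a per-row vector via reshape(-1).expand(...). The tensor names and values below are made up for illustration.

import torch

w = torch.randn(4, 8).to(torch.float8_e4m3fn)           # fake FP8 weight with 4 output rows
scalar_scale = torch.tensor(0.02)                        # per-tensor scale, numel() == 1
per_row_scale = torch.tensor([0.02, 0.01, 0.03, 0.04])   # per-row scale, numel() == rows

for scale in (scalar_scale, per_row_scale):
    # reshape(-1) makes the scale 1-D; expand() broadcasts a scalar to one entry
    # per output row without copying memory (hence the later .contiguous() call).
    scale = scale.reshape(-1).expand(w.shape[0])
    assert scale.shape == (w.shape[0],)
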
danieldk authored and yuanwu2017 committed Oct 25, 2024
1 parent 68cfc94 commit 32d50c2
Showing 1 changed file with 38 additions and 11 deletions.
server/text_generation_server/layers/fp8.py
@@ -87,9 +87,11 @@ def get_weights(self, weights: "Weights", prefix: str):
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -113,9 +115,16 @@ def get_weights_col_packed(
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_packed_sharded(
-                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            ).reshape(-1)
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if scale.numel() > 1:
+                scale = weights.get_packed_sharded(
+                    f"{prefix}.weight_scale",
+                    dim=0,
+                    block_sizes=block_sizes,
+                    to_dtype=False,
+                )
+            scale = scale.reshape(-1).expand(w.shape[0])
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -132,16 +141,19 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
         w = [
             weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
         ]
+        shapes = [x.shape for x in w]
 
         # Concat then send to the device
         w = torch.cat(w, dim=dim).to(weights.device)
 
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             scale = [
-                weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
-                for p in prefixes
+                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+                for p, shape in zip(prefixes, shapes)
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -157,9 +169,11 @@ def get_weights_row(self, weights: "Weights", prefix: str):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -182,6 +196,9 @@ class Fp8Weight(Weight):
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
             return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
+        # This is not checked by the fbgemm kernels, but they require contiguous
+        # memory. Can be non-contiguous when we e.g. expand from scalars.
+        self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
             self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
         )
@@ -222,6 +239,9 @@ def from_unquant(cls, weight, bias, dtype):
 
     @classmethod
     def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+        if FBGEMM_DYN_AVAILABLE:
+            # fbgemm needs float32 scales.
+            scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
@@ -256,3 +276,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             bias=self.bias,
         )
         return output
+
+
+def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
+    scale = weights.get_tensor(prefix, to_dtype=False)
+    if scale.numel() > 1:
+        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    return scale.reshape(-1).expand(shape[0])
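
As a standalone illustration (not from the commit) of why get_linear now calls .contiguous() on the scale and from_fp8 upcasts it with .float(): expanding a scalar yields a stride-0 view, and per the comments in the diff, the fbgemm kernels assume contiguous float32 scales.

import torch

scale = torch.tensor(0.02, dtype=torch.bfloat16)  # scalar scale as stored in the checkpoint
scale = scale.reshape(-1).expand(16)              # one entry per output row, stride-0 view

print(scale.is_contiguous())                      # False: expand() allocates no new memory
scale = scale.contiguous().float()                # materialize and upcast before the kernel
print(scale.is_contiguous(), scale.dtype)         # True torch.float32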
