
Commit c29dc89

Add support for scalar FP8 weight scales (#2550)
* Add support for scalar FP8 weight scales

* Support LLM compressor FP8 checkpoints on H100

  On H100, we use fbgemm-gpu, which requires bfloat16 as the input dtype. However, we wouldn't pick up fp8 quantization for models quantized with LLM compressor. This change adds enough parsing to detect if models have FP8-quantized weights.

* Remove stray debug print
1 parent 0ff6ff6 commit c29dc89
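The gist of the fp8.py changes below: checkpoints may store `weight_scale` either as a single per-tensor scalar or as a per-output-channel vector, and the loader now normalizes both to a per-row view. A minimal sketch of that normalization (tensor names and shapes here are illustrative, not the loader's real `Weights` API):

```python
import torch

out_features = 4
per_tensor_scale = torch.tensor(0.02)            # scalar: one scale for the whole weight
per_channel_scale = torch.rand(out_features, 1)  # one scale per output row


def normalize_scale(scale: torch.Tensor, out_features: int) -> torch.Tensor:
    # Flatten, then broadcast a scalar to one entry per output row.
    # expand() returns a view, so no data is copied here.
    return scale.reshape(-1).expand(out_features)


print(normalize_scale(per_tensor_scale, out_features).shape)   # torch.Size([4])
print(normalize_scale(per_channel_scale, out_features).shape)  # torch.Size([4])
```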

File tree

server/text_generation_server/layers/fp8.py
server/text_generation_server/models/__init__.py

2 files changed: +56 −12

server/text_generation_server/layers/fp8.py

Lines changed: 38 additions & 11 deletions
@@ -87,9 +87,11 @@ def get_weights(self, weights: "Weights", prefix: str):

         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -113,9 +115,16 @@ def get_weights_col_packed(

         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_packed_sharded(
-                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            ).reshape(-1)
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if scale.numel() > 1:
+                scale = weights.get_packed_sharded(
+                    f"{prefix}.weight_scale",
+                    dim=0,
+                    block_sizes=block_sizes,
+                    to_dtype=False,
+                )
+            scale = scale.reshape(-1).expand(w.shape[0])
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
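For the packed (fused, e.g. QKV-style) path above, only a per-channel scale needs the packed sharded load; a scalar applies to every row and can simply be broadcast. A rough sketch of that distinction, with made-up block sizes and row ranges rather than the real `get_packed_sharded` semantics:

```python
import torch

# Hypothetical fused projection whose rows are laid out as [Q | K | V];
# this rank keeps only the middle block. Sizes are illustrative.
block_sizes = [8, 8, 8]
kept_rows = slice(8, 16)


def scale_for_block(scale: torch.Tensor, rows: slice, n_rows: int) -> torch.Tensor:
    if scale.numel() == 1:
        # Per-tensor scalar: valid for every row, just broadcast it.
        return scale.reshape(-1).expand(n_rows)
    # Per-row scale: slice it the same way the packed weight was sliced.
    return scale.reshape(-1)[rows]


scalar = torch.tensor(0.05)
per_row = torch.rand(sum(block_sizes))
print(scale_for_block(scalar, kept_rows, 8).shape)   # torch.Size([8])
print(scale_for_block(per_row, kept_rows, 8).shape)  # torch.Size([8])
```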
@@ -132,16 +141,19 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: in
         w = [
             weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
         ]
+        shapes = [x.shape for x in w]
+
         # Concat then send to the device
         w = torch.cat(w, dim=dim).to(weights.device)

         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             scale = [
-                weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
-                for p in prefixes
+                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+                for p, shape in zip(prefixes, shapes)
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -157,9 +169,11 @@ def get_weights_row(self, weights: "Weights", prefix: str):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -182,6 +196,9 @@ class Fp8Weight(Weight):
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
             return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
+        # This is not checked by the fbgemm kernels, but they require contiguous
+        # memory. Can be non-contiguous when we e.g. expand from scalars.
+        self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
             self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
         )
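The `.contiguous()` call above matters because the expanded scalar scale is a stride-0 broadcast view, not real memory; a quick check of that behaviour in plain PyTorch:

```python
import torch

scale = torch.tensor(0.02).reshape(-1).expand(4)
print(scale.is_contiguous())               # False: expand() returns a stride-0 view
print(scale.contiguous().is_contiguous())  # True: .contiguous() materializes the values
```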
@@ -222,6 +239,9 @@ def from_unquant(cls, weight, bias, dtype):

     @classmethod
     def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+        if FBGEMM_DYN_AVAILABLE:
+            # fbgemm needs float32 scales.
+            scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
@@ -256,3 +276,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             bias=self.bias,
         )
         return output
+
+
+def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
+    scale = weights.get_tensor(prefix, to_dtype=False)
+    if scale.numel() > 1:
+        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    return scale.reshape(-1).expand(shape[0])
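`_load_scalar_or_matrix_scale` applies the same rule per prefix when several projections are concatenated: keep a scalar as a broadcast view, otherwise load the per-row scale, and always return one entry per output row. A standalone sketch of that behaviour, with a plain dict standing in for the `Weights` loader and no real sharding (both assumptions made for illustration):

```python
import torch

# Toy checkpoint: one scalar scale and one per-row scale. Names are illustrative.
checkpoint = {
    "q_proj.weight_scale": torch.tensor(0.03),  # per-tensor scalar
    "k_proj.weight_scale": torch.rand(16),      # one scale per output row
}


def load_scalar_or_matrix_scale(name: str, n_rows: int) -> torch.Tensor:
    scale = checkpoint[name]
    # The real helper would load a sharded slice when the scale is not a scalar;
    # here both cases simply end up as one scale per output row.
    return scale.reshape(-1).expand(n_rows)


scales = [
    load_scalar_or_matrix_scale("q_proj.weight_scale", 16),
    load_scalar_or_matrix_scale("k_proj.weight_scale", 16),
]
print(torch.cat(scales).shape)  # torch.Size([32]): one scale per concatenated row
```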

server/text_generation_server/models/__init__.py

Lines changed: 18 additions & 1 deletion
@@ -334,6 +334,7 @@ def get_model(
     model_type = config_dict.get("model_type", None)

     quantization_config = config_dict.get("quantization_config", None)
+    compression_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
         if method in {"gptq", "awq", "exl2"}:
@@ -344,6 +345,23 @@ def get_model(
             quantize = "fp8"
         else:
             log_master(logger.warning, f"Unknown quantization method {method}")
+    elif compression_config is not None:
+        # TODO: at some point we should probably fully parse the compression
+        # configuration to know which parameters are compressed.
+        config_groups = compression_config.get("config_groups")
+        if config_groups is not None:
+            for _, group in config_groups.items():
+                weights_config = group.get("weights")
+                if weights_config is not None:
+                    if (
+                        weights_config["type"] == "float"
+                        and weights_config["num_bits"] == 8
+                    ):
+                        log_master(
+                            logger.info, "Auto selecting quantization method fp8"
+                        )
+                        quantize = "fp8"
+                        break

     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
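For reference, an illustrative `compression_config` fragment (group name and layout assumed, not copied from a real checkpoint) that the detection above would match and auto-select fp8 for:

```python
config_dict = {
    "compression_config": {
        "config_groups": {
            # Group name is illustrative; only the "weights" entry is inspected.
            "group_0": {
                "weights": {"type": "float", "num_bits": 8},
            },
        },
    },
}

# Mirrors the detection above: 8-bit float weights => quantize = "fp8".
weights_config = config_dict["compression_config"]["config_groups"]["group_0"]["weights"]
assert weights_config["type"] == "float" and weights_config["num_bits"] == 8
```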
@@ -768,7 +786,6 @@ def get_model(
         )

     elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
-        print(f">>> model_type: {model_type}")
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
