Add support for scalar FP8 weight scales #2550

Merged: 3 commits, Sep 24, 2024
server/text_generation_server/layers/fp8.py (38 additions, 11 deletions)
@@ -87,9 +87,11 @@ def get_weights(self, weights: "Weights", prefix: str):
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -113,9 +115,16 @@ def get_weights_col_packed(
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_packed_sharded(
-                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            ).reshape(-1)
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if scale.numel() > 1:
+                scale = weights.get_packed_sharded(
+                    f"{prefix}.weight_scale",
+                    dim=0,
+                    block_sizes=block_sizes,
+                    to_dtype=False,
+                )
+            scale = scale.reshape(-1).expand(w.shape[0])
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -132,16 +141,19 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: in
         w = [
             weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
         ]
+        shapes = [x.shape for x in w]
 
         # Concat then send to the device
         w = torch.cat(w, dim=dim).to(weights.device)
+
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             scale = [
-                weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
-                for p in prefixes
+                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+                for p, shape in zip(prefixes, shapes)
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -157,9 +169,11 @@ def get_weights_row(self, weights: "Weights", prefix: str):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -182,6 +196,9 @@ class Fp8Weight(Weight):
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
            return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
+        # This is not checked by the fbgemm kernels, but they require contiguous
+        # memory. Can be non-contiguous when we e.g. expand from scalars.
+        self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
             self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
         )
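
The contiguity note above is easy to miss: expanding a scalar scale into per-row scales yields a zero-stride view rather than a dense buffer, and the kernels assume dense memory without checking. A minimal PyTorch sketch (shapes and values invented) of what the added .contiguous() call changes:

import torch

# A per-tensor scale expanded to per-row scales is a zero-stride view over a
# single element, so it is not contiguous in memory.
scale = torch.tensor([0.5], dtype=torch.float32).reshape(-1).expand(4096)
print(scale.shape, scale.stride(), scale.is_contiguous())
# torch.Size([4096]) (0,) False

# .contiguous() materializes a dense 4096-element buffer, which is what
# kernels that index raw memory expect.
scale = scale.contiguous()
print(scale.stride(), scale.is_contiguous())
# (1,) True
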
@@ -222,6 +239,9 @@ def from_unquant(cls, weight, bias, dtype):
 
     @classmethod
     def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+        if FBGEMM_DYN_AVAILABLE:
+            # fbgemm needs float32 scales.
+            scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
@@ -256,3 +276,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             bias=self.bias,
         )
         return output
+
+
+def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
+    scale = weights.get_tensor(prefix, to_dtype=False)
+    if scale.numel() > 1:
+        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    return scale.reshape(-1).expand(shape[0])
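
To make the dispatch in _load_scalar_or_matrix_scale concrete, here is a rough standalone sketch of the same reshape-and-expand behaviour on already-loaded tensors. The expand_scale helper below is hypothetical, it ignores the tensor-parallel sharding that get_sharded performs, and all values are invented:

import torch

def expand_scale(scale: torch.Tensor, out_features: int) -> torch.Tensor:
    # Scalar (per-tensor) scale: broadcast the single value to every output row.
    # Per-row scale: already one value per output feature, so just flatten it.
    return scale.reshape(-1).expand(out_features)

# Per-tensor FP8 checkpoint: weight_scale is a single scalar.
scalar_scale = torch.tensor(0.5, dtype=torch.float32)
print(expand_scale(scalar_scale, 4))
# tensor([0.5000, 0.5000, 0.5000, 0.5000])

# Per-row (channel-wise) checkpoint: one scale per output feature.
row_scale = torch.tensor([0.5, 0.25, 1.0, 2.0], dtype=torch.float32)
print(expand_scale(row_scale, 4))
# tensor([0.5000, 0.2500, 1.0000, 2.0000])
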
server/text_generation_server/models/__init__.py (18 additions, 1 deletion)
@@ -334,6 +334,7 @@ def get_model(
     model_type = config_dict.get("model_type", None)
 
     quantization_config = config_dict.get("quantization_config", None)
+    compression_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
         if method in {"gptq", "awq", "exl2"}:
@@ -344,6 +345,23 @@ def get_model(
             quantize = "fp8"
         else:
             log_master(logger.warning, f"Unknown quantization method {method}")
+    elif compression_config is not None:
+        # TODO: at some point we should probably fully parse the compression
+        # configuration to know which parameters are compressed.
+        config_groups = compression_config.get("config_groups")
+        if config_groups is not None:
+            for _, group in config_groups.items():
+                weights_config = group.get("weights")
+                if weights_config is not None:
+                    if (
+                        weights_config["type"] == "float"
+                        and weights_config["num_bits"] == 8
+                    ):
+                        log_master(
+                            logger.info, "Auto selecting quantization method fp8"
+                        )
+                        quantize = "fp8"
+                        break
 
     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
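
For illustration only, a stripped-down config_dict that this new elif branch would match could look like the following. Real compressed-tensors configurations carry many more fields; the model_type and group name here are invented, and only the keys the code actually reads (config_groups, weights, type, num_bits) come from the diff:

config_dict = {
    "model_type": "llama",
    "compression_config": {
        "config_groups": {
            "group_0": {
                # 8-bit floating-point weights, i.e. FP8.
                "weights": {"type": "float", "num_bits": 8},
            },
        },
    },
}

# Reproduce the detection: the first group whose weights are 8-bit floats
# selects the fp8 quantizer.
compression_config = config_dict.get("compression_config", None)
quantize = None
if compression_config is not None:
    for _, group in compression_config.get("config_groups", {}).items():
        weights_config = group.get("weights")
        if (
            weights_config is not None
            and weights_config["type"] == "float"
            and weights_config["num_bits"] == 8
        ):
            quantize = "fp8"
            break
print(quantize)  # fp8
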
@@ -768,7 +786,6 @@ def get_model(
         )
 
     elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
-        print(f">>> model_type: {model_type}")
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,