Add support for scalar FP8 weight scales
danieldk committed Sep 24, 2024
1 parent f478aa7 commit ccaf9ff
Showing 1 changed file with 35 additions and 11 deletions.
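
In short, the change lets checkpoints that store weight_scale as a single scalar (per-tensor FP8 quantization) flow through the same loading path as checkpoints that carry a per-output-channel scale vector: the scalar is expanded to one scale per output row. A minimal sketch of that idea, with illustrative shapes and values only:

import torch

out_features = 4

per_channel = torch.tensor([0.5, 0.25, 0.125, 0.0625])  # one scale per output row
per_tensor = torch.tensor(0.02)                          # a single scalar scale

# The pattern used throughout this commit: flatten, then broadcast to the row count.
print(per_tensor.reshape(-1).expand(out_features))   # tensor([0.0200, 0.0200, 0.0200, 0.0200])
print(per_channel.reshape(-1).expand(out_features))  # a no-op when the sizes already match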
server/text_generation_server/layers/fp8.py (46 changes: 35 additions & 11 deletions)
@@ -87,9 +87,11 @@ def get_weights(self, weights: "Weights", prefix: str):
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -113,9 +115,16 @@ def get_weights_col_packed(
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_packed_sharded(
-                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            ).reshape(-1)
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if scale.numel() > 1:
+                scale = weights.get_packed_sharded(
+                    f"{prefix}.weight_scale",
+                    dim=0,
+                    block_sizes=block_sizes,
+                    to_dtype=False,
+                )
+            scale = scale.reshape(-1).expand(w.shape[0])
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -132,16 +141,19 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: in
         w = [
             weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
         ]
+        shapes = [x.shape for x in w]
 
         # Concat then send to the device
         w = torch.cat(w, dim=dim).to(weights.device)
 
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             scale = [
-                weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
-                for p in prefixes
+                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+                for p, shape in zip(prefixes, shapes)
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
 
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
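
A note on the shapes list introduced above (plain tensors here, not the repository's Weights API): the per-prefix row counts are recorded before torch.cat erases them, because each prefix's scale, which may be a scalar, has to be expanded to that prefix's number of rows before the scales themselves are concatenated. A rough sketch under those assumptions:

import torch

# Hypothetical fused projections with different row counts per prefix.
shapes = [torch.Size([128, 64]), torch.Size([64, 64]), torch.Size([64, 64])]
# A mix of scalar (per-tensor) and per-channel scales.
scales = [torch.tensor(0.1), torch.rand(64), torch.tensor(0.3)]

expanded = [s.reshape(-1).expand(shape[0]) for s, shape in zip(scales, shapes)]
print(torch.cat(expanded, dim=0).shape)  # torch.Size([256]), one scale per fused output row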
@@ -157,9 +169,11 @@ def get_weights_row(self, weights: "Weights", prefix: str):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -182,6 +196,9 @@ class Fp8Weight(Weight):
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
             return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
+        # This is not checked by the fbgemm kernels, but they require contiguous
+        # memory. Can be non-contiguous when we e.g. expand from scalars.
+        self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
             self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
         )
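
The new comment in get_linear points at a PyTorch detail worth spelling out: expand() returns a broadcast view with stride 0 rather than allocating memory, so a scale expanded from a scalar is not contiguous, while the fbgemm FP8 kernels assume contiguous inputs without checking. A minimal illustration in plain PyTorch, with arbitrary values:

import torch

scale = torch.tensor(0.02).reshape(-1).expand(4096)
print(scale.is_contiguous(), scale.stride())  # False (0,): a zero-stride broadcast view
scale = scale.contiguous()                    # materializes a real 4096-element buffer
print(scale.is_contiguous())                  # True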
@@ -256,3 +273,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             bias=self.bias,
         )
         return output
+
+
+def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
+    scale = weights.get_tensor(prefix, to_dtype=False)
+    if scale.numel() > 1:
+        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    return scale.reshape(-1).expand(shape[0])
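
Roughly, the new helper returns a length shape[0] scale vector in both cases: a scalar checkpoint scale is broadcast to the shard's row count, while a per-channel scale is sharded along dim 0 together with the weight and then flattened. A standalone approximation with plain tensors (hypothetical sizes, hand-rolled slicing in place of the real get_sharded):

import torch

def load_scalar_or_matrix_scale_sketch(full_scale, shard_rows, rank):
    # Stand-in for weights.get_tensor / weights.get_sharded: slice this rank's rows ourselves.
    if full_scale.numel() > 1:
        full_scale = full_scale[rank * shard_rows : (rank + 1) * shard_rows]
    return full_scale.reshape(-1).expand(shard_rows)

print(load_scalar_or_matrix_scale_sketch(torch.tensor(0.02), 1024, rank=0).shape)   # torch.Size([1024])
print(load_scalar_or_matrix_scale_sketch(torch.rand(4096, 1), 1024, rank=2).shape)  # torch.Size([1024])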
