Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 19 additions & 25 deletions diffsynth_engine/utils/fp8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,39 +72,33 @@ def fp8_linear(
) -> torch.Tensor:
device = input.device
origin_dtype = input.dtype
scale_a = 1.0
origin_shape = input.shape
input = input.reshape(-1, origin_shape[-1])

x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
fp8_max = 448.0
# For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
# To avoid overflow and ensure numerical compatibility during FP8 computation,
# we scale down the input by 2.0 in advance.
# This scaling will be compensated later during the final result scaling.
if DTYPE_FP8 == torch.float8_e4m3fnuz:
scale_a = 2.0
input = input / scale_a
fp8_max = fp8_max / 2.0
scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
scale_b = torch.ones((weight.shape[0], 1)).float().to(device=device)
input = input / scale_a
input = input.to(DTYPE_FP8)
weight = weight.to(DTYPE_FP8)

if len(input.shape) > 2:
origin_shape = input.shape
input = input.reshape(-1, origin_shape[-1])
result = torch._scaled_mm(
input,
weight.T,
scale_a=torch.tensor(scale_a).to(device=device),
scale_b=torch.tensor(1.0).to(device=device),
bias=bias,
out_dtype=origin_dtype,
)
new_shape = origin_shape[:-1] + result.shape[-1:]
result = result.reshape(new_shape)
else:
result = torch._scaled_mm(
input,
weight.T,
scale_a=torch.tensor(scale_a).to(device=device),
scale_b=torch.tensor(1.0).to(device=device),
bias=bias,
out_dtype=origin_dtype,
)
result = torch._scaled_mm(
input,
weight.T,
scale_a=scale_a,
scale_b=scale_b.T,
bias=bias,
out_dtype=origin_dtype,
)
new_shape = origin_shape[:-1] + result.shape[-1:]
result = result.reshape(new_shape)
return result

F.linear = fp8_linear
Expand Down