19 changes: 14 additions & 5 deletions diffsynth_engine/utils/fp8_linear.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from contextlib import contextmanager
+from diffsynth_engine.utils.platform import DTYPE_FP8
 
 
 def enable_fp8_linear(module: nn.Module):
@@ -12,7 +13,7 @@ def enable_fp8_linear(module: nn.Module):
 def _enable_fp8_linear(module: nn.Module):
     if isinstance(module, nn.Linear) and torch.is_floating_point(module.weight.data):
         # avoid conversion for int weights like GGUF
-        module.weight.data = module.weight.data.to(torch.float8_e4m3fn)
+        module.weight.data = module.weight.data.to(DTYPE_FP8)
     for submodule in module.children():
         _enable_fp8_linear(submodule)

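A quick usage sketch of the conversion above. This is hypothetical and based only on what the diff shows: it assumes enable_fp8_linear (whose body is collapsed here) applies _enable_fp8_linear to the module tree, as its name and the recursion suggest.

# Hypothetical usage sketch, not part of the PR.
import torch.nn as nn
from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
from diffsynth_engine.utils.platform import DTYPE_FP8

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8))
enable_fp8_linear(model)
# Every floating-point nn.Linear weight should now be stored in the
# platform's FP8 dtype; non-Linear modules are left untouched.
assert model[0].weight.dtype == DTYPE_FP8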
@@ -32,16 +33,24 @@ def fp8_linear(
 ) -> torch.Tensor:
     device = input.device
     origin_dtype = input.dtype
-    input = input.to(torch.float8_e4m3fn)
-    weight = weight.to(torch.float8_e4m3fn)
+    scale_a = 1.0
+    # For float8_e4m3fnuz, the maximum representable value (240) is roughly half
+    # that of e4m3fn (448). To avoid overflow and ensure numerical compatibility
+    # during FP8 computation, we scale down the input by 2.0 in advance.
+    # This scaling is compensated later, via scale_a, in the final result.
+    if DTYPE_FP8 == torch.float8_e4m3fnuz:
+        scale_a = 2.0
+        input = input / scale_a
+    input = input.to(DTYPE_FP8)
+    weight = weight.to(DTYPE_FP8)
 
     if len(input.shape) > 2:
         origin_shape = input.shape
         input = input.reshape(-1, origin_shape[-1])
         result = torch._scaled_mm(
             input,
             weight.T,
-            scale_a=torch.tensor(1.0).to(device=device),
+            scale_a=torch.tensor(scale_a).to(device=device),
             scale_b=torch.tensor(1.0).to(device=device),
             bias=bias,
             out_dtype=origin_dtype,
@@ -52,7 +61,7 @@
         result = torch._scaled_mm(
             input,
             weight.T,
-            scale_a=torch.tensor(1.0).to(device=device),
+            scale_a=torch.tensor(scale_a).to(device=device),
             scale_b=torch.tensor(1.0).to(device=device),
             bias=bias,
             out_dtype=origin_dtype,
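Why the compensation works: torch._scaled_mm multiplies the FP8 product by scale_a * scale_b before casting to out_dtype, so pre-dividing the activations by 2.0 and passing 2.0 back as scale_a round-trips to the original values, up to FP8 rounding. A minimal sketch under stated assumptions: a GPU whose _scaled_mm kernel accepts the chosen FP8 dtype (float8_e4m3fnuz needs an AMD gfx94x card; on NVIDIA, substitute torch.float8_e4m3fn and scale_a = 1.0, which is exactly the PR's default path).

# Hedged sketch, not part of the PR: checks the scale-compensation identity.
import torch

device = "cuda"
x = torch.randn(16, 32, device=device, dtype=torch.bfloat16)
w = torch.randn(64, 32, device=device, dtype=torch.bfloat16)
ref = x @ w.T  # reference result in the original dtype

scale_a = 2.0  # pre-divide the activations, then hand the factor back as scale_a
x_fp8 = (x / scale_a).to(torch.float8_e4m3fnuz)  # assumes AMD gfx94x support
w_fp8 = w.to(torch.float8_e4m3fnuz)
out = torch._scaled_mm(
    x_fp8,
    w_fp8.T,
    scale_a=torch.tensor(scale_a, device=device),
    scale_b=torch.tensor(1.0, device=device),
    out_dtype=torch.bfloat16,
)
# (x / 2) @ w.T * 2 == x @ w.T, up to FP8 quantization error
print(torch.allclose(ref, out, atol=0.5, rtol=0.1))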
10 changes: 9 additions & 1 deletion diffsynth_engine/utils/platform.py
@@ -1,7 +1,15 @@
+# cross-platform definitions and utilities
 import torch
 import gc
 
-# stores cross-platform utility classes
+# data type
+# AMD only supports float8_e4m3fnuz
+# https://onnx.ai/onnx/technical/float8.html
+if torch.version.hip and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName:
+    DTYPE_FP8 = torch.float8_e4m3fnuz
+else:
+    DTYPE_FP8 = torch.float8_e4m3fn
+
 
 
 def empty_cache():
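The range gap behind the scale_a = 2.0 workaround can be verified directly with torch.finfo; this sketch runs on CPU with any recent PyTorch build that defines both dtypes.

# Sketch, not part of the PR: the FP8 ranges referenced in fp8_linear.py.
# e4m3fnuz spends no encodings on inf or negative zero, but its shifted
# exponent bias roughly halves the maximum finite value.
import torch

print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0 (OCP e4m3fn)
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0 (AMD variant)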