Skip to content

Commit 885b030

Browse files
committed
move
1 parent cbbadc6 commit 885b030

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

diffsynth_engine/utils/fp8_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch.nn as nn
33
import torch.nn.functional as F
44
from contextlib import contextmanager
5-
from diffsynth_engine.utils.constants import DTYPE_FP8
5+
from diffsynth_engine.utils.platform import DTYPE_FP8
66

77

88
def enable_fp8_linear(module: nn.Module):

diffsynth_engine/utils/platform.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1+
# cross-platform definitions and utilities
12
import torch
23
import gc
34

4-
# 存放跨平台的工具类
5+
6+
# data type
7+
# AMD only supports float8_e4m3fnuz
8+
# https://onnx.ai/onnx/technical/float8.html
9+
if torch.version.hip and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName:
10+
DTYPE_FP8 = torch.float8_e4m3fnuz
11+
else:
12+
DTYPE_FP8 = torch.float8_e4m3fn
513

614

715
def empty_cache():

0 commit comments

Comments
 (0)