Skip to content

Commit e01a222

Browse files
committed
move
1 parent cbbadc6 commit e01a222

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

diffsynth_engine/utils/constants.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
import torch
2+
33

44
PACKAGE_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
55
REPO_ROOT = os.path.dirname(PACKAGE_ROOT)
@@ -34,11 +34,3 @@
3434
MB = 1024 * KB
3535
GB = 1024 * MB
3636
TB = 1024 * GB
37-
38-
# data type
39-
# AMD only support e4m3fnuz
40-
# https://onnx.ai/onnx/technical/float8.html
41-
if torch.version.hip and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName:
42-
DTYPE_FP8 = torch.float8_e4m3fnuz
43-
else:
44-
DTYPE_FP8 = torch.float8_e4m3fn

diffsynth_engine/utils/fp8_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch.nn as nn
33
import torch.nn.functional as F
44
from contextlib import contextmanager
5-
from diffsynth_engine.utils.constants import DTYPE_FP8
5+
from diffsynth_engine.utils.platform import DTYPE_FP8
66

77

88
def enable_fp8_linear(module: nn.Module):

diffsynth_engine/utils/platform.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1+
# cross-platform definitions and utilities
12
import torch
23
import gc
34

4-
# 存放跨平台的工具类
5+
6+
# data type
7+
# AMD only supports float8_e4m3fnuz
8+
# https://onnx.ai/onnx/technical/float8.html
9+
if torch.version.hip and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName:
10+
DTYPE_FP8 = torch.float8_e4m3fnuz
11+
else:
12+
DTYPE_FP8 = torch.float8_e4m3fn
513

614

715
def empty_cache():

0 commit comments

Comments (0)