commit 885b030 (1 parent: cbbadc6)
diffsynth_engine/utils/fp8_linear.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from contextlib import contextmanager
-from diffsynth_engine.utils.constants import DTYPE_FP8
+from diffsynth_engine.utils.platform import DTYPE_FP8
 
 
 def enable_fp8_linear(module: nn.Module):
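
The only change in this file is the import source: DTYPE_FP8 now comes from diffsynth_engine.utils.platform, where it is resolved per platform at import time, instead of being a fixed constant in constants.py. As a rough, hypothetical sketch of how such a dtype is typically consumed by an fp8-linear helper (the real body of enable_fp8_linear is not shown in this diff, and cast_linear_weights_to_fp8 below is an illustrative name, not the repo's API):

# Hypothetical sketch (not the body of enable_fp8_linear, which this diff
# does not show): how a helper might cast nn.Linear weights to the
# platform-selected fp8 dtype. cast_linear_weights_to_fp8 is an
# illustrative name, not part of the repo's API.
import torch.nn as nn

from diffsynth_engine.utils.platform import DTYPE_FP8


def cast_linear_weights_to_fp8(module: nn.Module) -> None:
    for submodule in module.modules():
        if isinstance(submodule, nn.Linear):
            # Weights are stored in fp8; the matmul itself typically runs
            # after an upcast or via a scaled fp8 kernel.
            submodule.weight.data = submodule.weight.data.to(DTYPE_FP8)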
diffsynth_engine/utils/platform.py
@@ -1,7 +1,15 @@
+# cross-platform definitions and utilities
 import torch
 import gc
-# 存放跨平台的工具类
+
+# data type
+# AMD only supports float8_e4m3fnuz
+# https://onnx.ai/onnx/technical/float8.html
+if torch.version.hip and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName:
+    DTYPE_FP8 = torch.float8_e4m3fnuz
+else:
+    DTYPE_FP8 = torch.float8_e4m3fn
 
 
 def empty_cache():
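
The new block picks the fp8 dtype at import time: ROCm builds (torch.version.hip) running on gfx94x hardware get torch.float8_e4m3fnuz, since AMD only supports the fnuz encoding (see the linked ONNX float8 page); everything else gets torch.float8_e4m3fn. A minimal sketch to check which dtype the selection resolves to on a given machine (assumes a GPU-enabled PyTorch build; the getattr guard is mine, since gcnArchName is only present on ROCm device properties):

# Minimal sketch to confirm which fp8 dtype the platform check resolved to.
# Assumes a GPU-enabled PyTorch build; the getattr default is a guard I
# added because gcnArchName only exists on ROCm device properties.
import torch

from diffsynth_engine.utils.platform import DTYPE_FP8

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print("device:", props.name, "arch:", getattr(props, "gcnArchName", "n/a"))
print("DTYPE_FP8:", DTYPE_FP8)
# Expected: torch.float8_e4m3fnuz on gfx94x ROCm devices (e.g. MI300 series),
# torch.float8_e4m3fn otherwise.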