Skip to content

Commit c8674bc

Browse files
Enable RDNA4 pytorch attention on ROCm 7.0 and up. (#10332)
1 parent 3dfdcf6 commit c8674bc

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

comfy/model_management.py

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -345,9 +345,9 @@ def amd_min_version(device=None, min_rdna_version=0):
345345
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
346346
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
347347
ENABLE_PYTORCH_ATTENTION = True
348-
# if torch_version_numeric >= (2, 8):
349-
# if any((a in arch) for a in ["gfx1201"]):
350-
# ENABLE_PYTORCH_ATTENTION = True
348+
if rocm_version >= (7, 0):
349+
if any((a in arch) for a in ["gfx1201"]):
350+
ENABLE_PYTORCH_ATTENTION = True
351351
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
352352
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]): # TODO: more arches
353353
SUPPORT_FP8_OPS = True

0 commit comments

Comments (0)