From d6b64568739952fd95bf4eb172d6fbbdd53964d1 Mon Sep 17 00:00:00 2001
From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Date: Fri, 16 Aug 2024 21:05:42 +0000
Subject: [PATCH] revert disable flash operator on rocm

---
 xformers/ops/fmha/flash.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index 14a8335ec1..49e708dc28 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -607,10 +607,7 @@ class FwOp(AttentionFwOpBase):
     implementation.
     """
 
-    if torch.version.hip:
-        OPERATOR = None
-    else:
-        OPERATOR = get_operator("xformers_flash", "flash_fwd")
+    OPERATOR = get_operator("xformers_flash", "flash_fwd")
     SUPPORTED_DEVICES: Set[str] = {"cuda"}
     CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
@@ -812,10 +809,7 @@ def operator_flop(
 class BwOp(AttentionBwOpBase):
     __doc__ = FwOp.__doc__
 
-    if torch.version.hip:
-        OPERATOR = None
-    else:
-        OPERATOR = get_operator("xformers_flash", "flash_bwd")
+    OPERATOR = get_operator("xformers_flash", "flash_bwd")
     SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
     CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
     SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
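
Reviewer note (not part of the patch): a minimal sketch of what this revert changes at import time, assuming an xformers build where xformers.ops.fmha.flash imports cleanly. Only FwOp, BwOp, OPERATOR, get_operator, and torch.version.hip come from the patch itself; the script around them is illustrative.

    import torch
    from xformers.ops.fmha import flash

    # Before this revert, a ROCm/HIP build (torch.version.hip is not None)
    # forced OPERATOR = None on both FwOp and BwOp, so the flash attention
    # kernels were never dispatched on AMD GPUs. After it, every build
    # resolves the operator through get_operator("xformers_flash", ...).
    print("HIP build:", torch.version.hip is not None)
    print("flash_fwd operator:", flash.FwOp.OPERATOR)
    print("flash_bwd operator:", flash.BwOp.OPERATOR)

With the conditional gone, availability on ROCm is decided by whether the xformers_flash extension was actually built, rather than being hard-disabled at class-definition time.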