revert disable flash operator on rocm
tenpercent committed Aug 16, 2024
1 parent 7d21800 commit d6b6456
Showing 1 changed file with 2 additions and 8 deletions.
xformers/ops/fmha/flash.py (10 changes: 2 additions & 8 deletions)

@@ -607,10 +607,7 @@ class FwOp(AttentionFwOpBase):
     implementation.
     """
 
-    if torch.version.hip:
-        OPERATOR = None
-    else:
-        OPERATOR = get_operator("xformers_flash", "flash_fwd")
+    OPERATOR = get_operator("xformers_flash", "flash_fwd")
     SUPPORTED_DEVICES: Set[str] = {"cuda"}
     CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
@@ -812,10 +809,7 @@ def operator_flop(
 class BwOp(AttentionBwOpBase):
     __doc__ = FwOp.__doc__
 
-    if torch.version.hip:
-        OPERATOR = None
-    else:
-        OPERATOR = get_operator("xformers_flash", "flash_bwd")
+    OPERATOR = get_operator("xformers_flash", "flash_bwd")
     SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
     CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
     SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
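For readers skimming the diff: before this revert, FwOp.OPERATOR and BwOp.OPERATOR were forced to None on HIP builds, which disabled the flash operators on ROCm; with the revert they are again resolved through get_operator on both CUDA and ROCm builds. A minimal sketch (not part of the commit) for observing the effect, assuming an xformers build in which the flash extension is available:

import torch
from xformers.ops.fmha import flash

# With this revert applied, the operators are looked up via get_operator()
# regardless of whether torch.version.hip is set, instead of being None on HIP.
print("HIP build:", bool(torch.version.hip))
print("flash_fwd operator:", flash.FwOp.OPERATOR)
print("flash_bwd operator:", flash.BwOp.OPERATOR)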
