@@ -82,23 +82,25 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
     if backend_by_env_var is not None:
         selected_backend = backend_name_to_enum(backend_by_env_var)
     if selected_backend is None:
-        # For Volta and Turing GPUs, use xformers instead.
-        device_available = current_platform.has_device_capability(80)
-        if device_available and support_fa:
-            from transformers.utils import is_flash_attn_2_available
-            if is_flash_attn_2_available():
-                selected_backend = _Backend.FLASH_ATTN
+        if current_platform.is_cuda():
+            device_available = current_platform.has_device_capability(80)
+            if device_available and support_fa:
+                from transformers.utils import is_flash_attn_2_available
+                if is_flash_attn_2_available():
+                    selected_backend = _Backend.FLASH_ATTN
+                else:
+                    logger.warning_once(
+                        "Current `vllm-flash-attn` has a bug inside vision "
+                        "module, so we use xformers backend instead. You can "
+                        "run `pip install flash-attn` to use flash-attention "
+                        "backend.")
+                    selected_backend = _Backend.XFORMERS
             else:
-                logger.warning_once(
-                    "Current `vllm-flash-attn` has a bug inside vision module, "
-                    "so we use xformers backend instead. You can run "
-                    "`pip install flash-attn` to use flash-attention backend.")
+                # For Volta and Turing GPUs, use xformers instead.
                 selected_backend = _Backend.XFORMERS
-        elif current_platform.is_cpu() or current_platform.is_rocm():
-            # ROCM doesn't support xformers
-            selected_backend = _Backend.TORCH_SDPA
         else:
-            selected_backend = _Backend.XFORMERS
+            # Default to torch SDPA for other non-GPU platforms.
+            selected_backend = _Backend.TORCH_SDPA
     return selected_backend
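For reference, here is a minimal sketch of how a caller might dispatch on the backend this function returns. The `_Backend` stub and the `run_vit_attn` helper are illustrative assumptions, not part of this change; only the three library entry points (`flash_attn.flash_attn_func`, `xformers.ops.memory_efficient_attention`, and `torch.nn.functional.scaled_dot_product_attention`) are real APIs.

```python
# Illustrative sketch only: `_Backend` is stubbed here and `run_vit_attn`
# is a hypothetical helper, not part of this change.
from enum import Enum, auto

import torch
import torch.nn.functional as F


class _Backend(Enum):
    FLASH_ATTN = auto()
    XFORMERS = auto()
    TORCH_SDPA = auto()


def run_vit_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                 backend: _Backend) -> torch.Tensor:
    """Run attention on (batch, heads, seq, head_dim) tensors."""
    if backend == _Backend.FLASH_ATTN:
        from flash_attn import flash_attn_func
        # flash-attn expects (batch, seq, heads, head_dim) fp16/bf16 inputs.
        out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2),
                              v.transpose(1, 2))
        return out.transpose(1, 2)
    if backend == _Backend.XFORMERS:
        from xformers.ops import memory_efficient_attention
        # xformers uses the same (batch, seq, heads, head_dim) layout.
        out = memory_efficient_attention(q.transpose(1, 2), k.transpose(1, 2),
                                         v.transpose(1, 2))
        return out.transpose(1, 2)
    # TORCH_SDPA: PyTorch's built-in kernel takes (batch, heads, seq, head_dim).
    return F.scaled_dot_product_attention(q, k, v)
```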