Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Support SDPA attention for Molmo vision backbone #9410

Merged
merged 6 commits into from
Oct 16, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
support molmo on CPU
  • Loading branch information
Isotr0py committed Oct 16, 2024
commit 255f8b2a89624fec01f486d7e9eec4f95354c13d
9 changes: 7 additions & 2 deletions vllm/model_executor/models/molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.transformers_utils.processor import get_processor
from vllm.utils import is_cpu

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -196,7 +197,7 @@ def __init__(
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
# For Volta and Turing GPUs, use xformers instead.
device_available = current_platform.get_device_capability()[0] >= 8
device_available = current_platform.has_device_capability(80)
if device_available:
from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available():
Expand All @@ -213,7 +214,7 @@ def __init__(
else:
if selected_backend == _Backend.FLASH_ATTN:
self._use_flash_attn = True
elif selected_backend == _Backend.XFORMERS:
elif selected_backend in [_Backend.XFORMERS, _Backend.TORCH_SDPA]:
self._use_flash_attn = False
else:
raise RuntimeError(
Expand Down Expand Up @@ -242,6 +243,10 @@ def forward(self,
if self._use_flash_attn:
from flash_attn import flash_attn_func
output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
elif is_cpu():
xq, xk, xv = (rearrange(x, "b s h d -> b h s d") for x in (xq, xk, xv))
output = F.scaled_dot_product_attention(xq, xk, xv)
output = rearrange(output, "b h s d -> b s h d ")
else:
from xformers import ops as xops
output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)
Expand Down