Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix/CI] Fix broken kernels/test_mha.py #12450

Merged
merged 3 commits into from
Jan 26, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
improve
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
  • Loading branch information
tlrmchlsmth committed Jan 26, 2025
commit dabdae56083e82ccf6e5d530b35bd7d30b376097
18 changes: 8 additions & 10 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ def __init__(
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(head_size,
dtype,
Expand Down Expand Up @@ -240,19 +243,14 @@ def forward(
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)

if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops

# Expand key and value to match number of query heads
if self.num_kv_heads != self.num_heads:
assert self.num_heads % self.num_kv_heads == 0
key = key.repeat_interleave(self.num_heads //
self.num_kv_heads,
dim=2)
value = value.repeat_interleave(self.num_heads //
self.num_kv_heads,
dim=2)

out = xops.memory_efficient_attention_forward(query,
key,
value,
Expand Down
Loading