Skip to content

Enable SDPA without kv cache #8950

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 17 additions & 25 deletions examples/models/llama/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(

def forward(
self,
input_pos: torch.Tensor,
input_pos: Optional[torch.Tensor],
q: torch.Tensor, # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
k: torch.Tensor, # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
v: torch.Tensor, # (bs, n_local_kv_heads, seqlen, head_dim)
Expand Down Expand Up @@ -218,13 +218,17 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
self.head_dim,
args.enable_dynamic_shape,
)
self.SDPA = SDPA(
dim=self.n_local_heads * self.head_dim,
head_dim=self.head_dim,
n_rep=self.n_rep,
max_context_len=self.max_context_len,
enable_dynamic_shape=args.enable_dynamic_shape,
)
else:
# Use a constant state to avoid export error
self.zero_pos = torch.tensor([0])

self.SDPA = SDPA(
dim=self.n_local_heads * self.head_dim,
head_dim=self.head_dim,
n_rep=self.n_rep,
max_context_len=self.max_context_len,
enable_dynamic_shape=args.enable_dynamic_shape,
)

def forward(
self,
Expand Down Expand Up @@ -258,20 +262,8 @@ def forward(
assert input_pos is not None
k, v = self.kv_cache.update(input_pos, k, v)
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
return self.wo(output), None

# grouped multiquery attention: expand out keys and values
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)

assert hasattr(self, "mask")

mask = self.mask[:seqlen, :seqlen]

output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

output = self.wo(output)

return output, None
else:
mask = self.mask[:seqlen, :seqlen]
# No kv cache. Pass 0 input_pos
output = self.SDPA(self.zero_pos, q, k, v, bsz, seqlen, mask)
return self.wo(output), None
88 changes: 88 additions & 0 deletions examples/models/llama/tests/test_attention_sma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import unittest

import torch
from executorch.examples.models.llama.attention import (
AttentionMHA,
KVCache,
ModelArgs,
Rope,
SDPA,
)


class TestAttentionMHA(unittest.TestCase):

def create_mock_args(self):
return ModelArgs(
use_kv_cache=True,
n_heads=8,
n_kv_heads=4,
head_dim=64,
max_batch_size=2,
max_context_len=16,
dim=512,
attention_qkv_bias=False,
enable_dynamic_shape=False,
)

def test_attentionmha_init(self):
args = self.create_mock_args()
rope = Rope(args)
attn = AttentionMHA(args, layer_id=0, rope=rope)

self.assertEqual(attn.n_heads, 8)
self.assertEqual(attn.n_kv_heads, 4)
self.assertEqual(attn.n_local_heads, 8)
self.assertEqual(attn.n_local_kv_heads, 4)
self.assertEqual(attn.head_dim, 64)
self.assertEqual(attn.dim, 512)
self.assertEqual(attn.mask.shape, (16, 16)) # Causal mask shape check
self.assertTrue(attn.use_kv_cache)

if attn.use_kv_cache:
self.assertIsInstance(attn.kv_cache, KVCache)
self.assertIsInstance(attn.SDPA, SDPA)

def test_attentionmha_forward(self):
args = self.create_mock_args()
rope = Rope(args)
attn = AttentionMHA(args, layer_id=0, rope=rope)

bsz, seqlen, dim = 2, 4, args.dim
x = torch.randn(bsz, seqlen, dim)
freqs_cos = torch.randn(seqlen, args.head_dim // 2)
freqs_sin = torch.randn(seqlen, args.head_dim // 2)
input_pos = torch.tensor([0, 1, 2, 3])

output, _ = attn.forward(x, freqs_cos, freqs_sin, input_pos=input_pos)

self.assertEqual(output.shape, (bsz, seqlen, dim))

def test_attentionmha_forward_no_kv_cache(self):
args = self.create_mock_args()
args.use_kv_cache = False # Disable KV cache for this test
rope = Rope(args)
attn = AttentionMHA(args, layer_id=0, rope=rope)

bsz, seqlen, dim = 2, 4, args.dim
x = torch.randn(bsz, seqlen, dim)
freqs_cos = torch.randn(seqlen, args.head_dim // 2)
freqs_sin = torch.randn(seqlen, args.head_dim // 2)

output, _ = attn.forward(x, freqs_cos, freqs_sin)

self.assertEqual(output.shape, (bsz, seqlen, dim))

def test_attentionmha_invalid_kv_cache(self):
args = self.create_mock_args()
rope = Rope(args)
attn = AttentionMHA(args, layer_id=0, rope=rope)

bsz, seqlen, dim = 2, 4, args.dim
x = torch.randn(bsz, seqlen, dim)
freqs_cos = torch.randn(seqlen, args.head_dim // 2)
freqs_sin = torch.randn(seqlen, args.head_dim // 2)

# No input_pos provided, should raise assertion error
with self.assertRaises(AssertionError):
attn.forward(x, freqs_cos, freqs_sin)
Loading